diff options
97 files changed, 417 insertions, 267 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index e97f601fb74..26e9fa8a52a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -81,6 +81,7 @@ add_subdirectory(messagebus) add_subdirectory(messagebus_test) add_subdirectory(metrics) add_subdirectory(model-evaluation) +add_subdirectory(model-integration) add_subdirectory(node-repository) add_subdirectory(orchestrator) add_subdirectory(persistence) diff --git a/config-model/pom.xml b/config-model/pom.xml index a65a6a836ed..b377760249a 100644 --- a/config-model/pom.xml +++ b/config-model/pom.xml @@ -293,16 +293,10 @@ <scope>test</scope> </dependency> <dependency> - <groupId>com.google.protobuf</groupId> - <artifactId>protobuf-java</artifactId> - </dependency> - <dependency> - <groupId>org.tensorflow</groupId> - <artifactId>proto</artifactId> - </dependency> - <dependency> - <groupId>org.tensorflow</groupId> - <artifactId>tensorflow</artifactId> + <groupId>com.yahoo.vespa</groupId> + <artifactId>model-integration</artifactId> + <version>${project.version}</version> + <scope>test</scope> </dependency> </dependencies> diff --git a/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java b/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java index 574c25a2f84..d784722f77a 100644 --- a/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java +++ b/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java @@ -21,8 +21,9 @@ import com.yahoo.config.provision.Zone; import com.yahoo.io.reader.NamedReader; import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.SearchBuilder; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import com.yahoo.searchdefinition.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ModelImporter; import com.yahoo.vespa.config.ConfigDefinition; import com.yahoo.vespa.config.ConfigDefinitionBuilder; import com.yahoo.vespa.config.ConfigDefinitionKey; @@ -39,6 +40,7 @@ import java.io.IOException; import java.io.Reader; import java.time.Instant; import java.util.Collection; +import java.util.Collections; import java.util.HashSet; import java.util.LinkedHashMap; import java.util.List; @@ -80,11 +82,23 @@ public class DeployState implements ConfigDefinitionStore { return new Builder().applicationPackage(applicationPackage).build(); } - private DeployState(ApplicationPackage applicationPackage, SearchDocumentModel searchDocumentModel, RankProfileRegistry rankProfileRegistry, - FileRegistry fileRegistry, DeployLogger deployLogger, Optional<HostProvisioner> hostProvisioner, DeployProperties properties, - Optional<ApplicationPackage> permanentApplicationPackage, Optional<ConfigDefinitionRepo> configDefinitionRepo, - java.util.Optional<Model> previousModel, Set<Rotation> rotations, Zone zone, QueryProfiles queryProfiles, - SemanticRules semanticRules, Instant now, Version wantedNodeVespaVersion) { + private DeployState(ApplicationPackage applicationPackage, + SearchDocumentModel searchDocumentModel, + RankProfileRegistry rankProfileRegistry, + FileRegistry fileRegistry, + DeployLogger deployLogger, + Optional<HostProvisioner> hostProvisioner, + DeployProperties properties, + Optional<ApplicationPackage> permanentApplicationPackage, + Optional<ConfigDefinitionRepo> configDefinitionRepo, + java.util.Optional<Model> previousModel, + Set<Rotation> rotations, + Collection<ModelImporter> modelImporters, + Zone zone, + QueryProfiles queryProfiles, + SemanticRules semanticRules, + Instant now, + Version wantedNodeVespaVersion) { this.logger = deployLogger; this.fileRegistry = fileRegistry; this.rankProfileRegistry = rankProfileRegistry; @@ -100,7 +114,8 @@ public class DeployState implements ConfigDefinitionStore { this.zone = zone; this.queryProfiles = queryProfiles; // TODO: Remove this by seeing how pagetemplates are propagated this.semanticRules = semanticRules; // TODO: Remove this by seeing how pagetemplates are propagated - this.importedModels = new ImportedModels(applicationPackage.getFileReference(ApplicationPackage.MODELS_DIR)); + this.importedModels = new ImportedModels(applicationPackage.getFileReference(ApplicationPackage.MODELS_DIR), + modelImporters); this.validationOverrides = applicationPackage.getValidationOverrides().map(ValidationOverrides::fromXml).orElse(ValidationOverrides.empty); this.wantedNodeVespaVersion = wantedNodeVespaVersion; @@ -232,6 +247,7 @@ public class DeployState implements ConfigDefinitionStore { private Optional<ConfigDefinitionRepo> configDefinitionRepo = Optional.empty(); private Optional<Model> previousModel = Optional.empty(); private Set<Rotation> rotations = new HashSet<>(); + private Collection<ModelImporter> modelImporters = Collections.emptyList(); private Zone zone = Zone.defaultZone(); private Instant now = Instant.now(); private Version wantedNodeVespaVersion = Vtag.currentVersion; @@ -281,6 +297,11 @@ public class DeployState implements ConfigDefinitionStore { return this; } + public Builder modelImporters(Collection<ModelImporter> modelImporters) { + this.modelImporters = modelImporters; + return this; + } + public Builder zone(Zone zone) { this.zone = zone; return this; @@ -305,9 +326,23 @@ public class DeployState implements ConfigDefinitionStore { QueryProfiles queryProfiles = new QueryProfilesBuilder().build(applicationPackage); SemanticRules semanticRules = new SemanticRuleBuilder().build(applicationPackage); SearchDocumentModel searchDocumentModel = createSearchDocumentModel(rankProfileRegistry, logger, queryProfiles, validationParameters); - return new DeployState(applicationPackage, searchDocumentModel, rankProfileRegistry, fileRegistry, logger, hostProvisioner, - properties, permanentApplicationPackage, configDefinitionRepo, previousModel, rotations, - zone, queryProfiles, semanticRules, now, wantedNodeVespaVersion); + return new DeployState(applicationPackage, + searchDocumentModel, + rankProfileRegistry, + fileRegistry, + logger, + hostProvisioner, + properties, + permanentApplicationPackage, + configDefinitionRepo, + previousModel, + rotations, + modelImporters, + zone, + queryProfiles, + semanticRules, + now, + wantedNodeVespaVersion); } private SearchDocumentModel createSearchDocumentModel(RankProfileRegistry rankProfileRegistry, diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index 78f61d7192d..a0cbc4271f3 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -16,7 +16,7 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.TypeContext; diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java index 0c5c7733dda..eb221a80638 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java @@ -12,7 +12,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.derived.validation.Validation; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import java.io.IOException; import java.io.Writer; diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java index fcbfb47c597..c955cd9956b 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java @@ -5,7 +5,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.RankingConstant; import com.yahoo.searchdefinition.RankingConstants; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import com.yahoo.vespa.config.search.RankProfilesConfig; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.Search; diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java index a1b0e72051b..fd468c8eca7 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java @@ -9,7 +9,7 @@ import com.yahoo.searchdefinition.document.RankType; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.SerializationContext; diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java index 2fe2dacf2ce..f52eedc8d08 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java @@ -4,7 +4,7 @@ package com.yahoo.searchdefinition.expressiontransforms; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import com.yahoo.searchlib.rankingexpression.transform.TransformContext; import java.util.HashMap; diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java index fef198f7939..b4de5e925f6 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java @@ -33,8 +33,8 @@ import com.yahoo.searchdefinition.derived.RankProfileList; import com.yahoo.searchdefinition.processing.Processing; import com.yahoo.vespa.model.container.search.QueryProfiles; import com.yahoo.vespa.model.ml.ConvertedModel; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import com.yahoo.vespa.config.ConfigDefinitionKey; import com.yahoo.vespa.config.ConfigKey; import com.yahoo.vespa.config.ConfigPayload; @@ -150,7 +150,8 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri this(configModelRegistry, deployState, true, null); } - private VespaModel(ConfigModelRegistry configModelRegistry, DeployState deployState, boolean complete, FileDistributor fileDistributor) throws IOException, SAXException { + private VespaModel(ConfigModelRegistry configModelRegistry, DeployState deployState, boolean complete, FileDistributor fileDistributor) + throws IOException, SAXException { super("vespamodel"); this.validationOverrides = deployState.validationOverrides(); configModelRegistry = new VespaConfigModelRegistry(configModelRegistry); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java index 2af9b297e9e..ff64b6753f9 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java @@ -21,6 +21,7 @@ import com.yahoo.config.model.deploy.DeployProperties; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.config.provision.Version; import com.yahoo.config.provision.Zone; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ModelImporter; import com.yahoo.vespa.config.VespaVersion; import com.yahoo.vespa.model.application.validation.Validation; @@ -29,6 +30,8 @@ import org.xml.sax.SAXException; import java.io.IOException; import java.time.Clock; import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; import java.util.List; import java.util.logging.Logger; @@ -41,19 +44,17 @@ public class VespaModelFactory implements ModelFactory { private static final Logger log = Logger.getLogger(VespaModelFactory.class.getName()); private final ConfigModelRegistry configModelRegistry; + private final Collection<ModelImporter> modelImporters; private final Zone zone; private final Clock clock; private final Version version; /** Creates a factory for vespa models for this version of the source */ @Inject - public VespaModelFactory(ComponentRegistry<ConfigModelPlugin> pluginRegistry, Zone zone) { - this(Version.fromIntValues(VespaVersion.major, VespaVersion.minor, VespaVersion.micro), pluginRegistry, zone); - } - - /** Creates a factory for vespa models of a particular version */ - public VespaModelFactory(Version version, ComponentRegistry<ConfigModelPlugin> pluginRegistry, Zone zone) { - this.version = version; + public VespaModelFactory(ComponentRegistry<ConfigModelPlugin> pluginRegistry, + ComponentRegistry<ModelImporter> modelImporters, + Zone zone) { + this.version = Version.fromIntValues(VespaVersion.major, VespaVersion.minor, VespaVersion.micro); List<ConfigModelBuilder> modelBuilders = new ArrayList<>(); for (ConfigModelPlugin plugin : pluginRegistry.allComponents()) { if (plugin instanceof ConfigModelBuilder) { @@ -61,6 +62,7 @@ public class VespaModelFactory implements ModelFactory { } } this.configModelRegistry = new MapConfigModelRegistry(modelBuilders); + this.modelImporters = modelImporters.allComponents(); this.zone = zone; this.clock = Clock.systemUTC(); } @@ -79,6 +81,7 @@ public class VespaModelFactory implements ModelFactory { } else { this.configModelRegistry = configModelRegistry; } + this.modelImporters = Collections.emptyList(); this.zone = Zone.defaultZone(); this.clock = clock; } @@ -137,6 +140,7 @@ public class VespaModelFactory implements ModelFactory { .properties(createDeployProperties(modelContext.properties())) .modelHostProvisioner(createHostProvisioner(modelContext)) .rotations(modelContext.properties().rotations()) + .modelImporters(modelImporters) .zone(zone) .now(clock.instant()) .wantedNodeVespaVersion(modelContext.wantedNodeVespaVersion()); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index 30586b1e677..18f3cd6e088 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -19,7 +19,7 @@ import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; @@ -41,7 +41,6 @@ import com.yahoo.tensor.serialization.TypedBinaryFormat; import java.io.BufferedReader; import java.io.File; import java.io.IOException; -import java.io.Reader; import java.io.StringReader; import java.io.UncheckedIOException; import java.util.ArrayList; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java index 07a36832094..ad1676ee7e1 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java @@ -4,7 +4,7 @@ package com.yahoo.searchdefinition; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.derived.DerivedConfiguration; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import com.yahoo.yolean.Exceptions; import org.junit.Test; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java index 4df3add13c5..14847938f68 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java @@ -16,7 +16,7 @@ import com.yahoo.searchdefinition.document.RankType; import com.yahoo.searchdefinition.document.SDDocumentType; import com.yahoo.searchdefinition.document.SDField; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import org.junit.Test; import java.util.Iterator; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java index 8df3985fd24..4f46d813ba5 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java @@ -5,7 +5,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import org.junit.Test; import static org.junit.Assert.assertEquals; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java index 150469cc928..551ef6968d2 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java @@ -3,7 +3,7 @@ package com.yahoo.searchdefinition; import com.yahoo.collections.Pair; import com.yahoo.search.query.profile.QueryProfileRegistry; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import com.yahoo.yolean.Exceptions; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java index e507a6c48e4..79adb1cf6bf 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java @@ -6,8 +6,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; -import com.yahoo.yolean.Exceptions; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import org.junit.Test; import java.util.Optional; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java index 6a1e5b207c6..d97dfb0cbf9 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java @@ -9,7 +9,7 @@ import com.yahoo.search.query.profile.types.QueryProfileType; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import org.junit.Test; import java.util.List; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java index 5e649c2e551..dedfbce6ae8 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java @@ -4,9 +4,8 @@ package com.yahoo.searchdefinition; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.derived.DerivedConfiguration; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import com.yahoo.yolean.Exceptions; -import org.junit.Ignore; import org.junit.Test; import static org.junit.Assert.assertEquals; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java index f39896f5779..437f43976b7 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java @@ -3,12 +3,11 @@ package com.yahoo.searchdefinition.derived; import com.yahoo.document.DocumenttypesConfig; import com.yahoo.document.config.DocumentmanagerConfig; -import com.yahoo.io.IOUtils; import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.SearchDefinitionTestCase; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import com.yahoo.vespa.configmodel.producers.DocumentManager; import com.yahoo.vespa.configmodel.producers.DocumentTypes; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java index f4344c9b03c..409b4236a9c 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java @@ -9,12 +9,9 @@ import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.SearchDefinitionTestCase; import com.yahoo.searchdefinition.document.SDDocumentType; import com.yahoo.searchdefinition.document.SDField; -import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import org.junit.Test; -import java.io.IOException; - /** * Tests deriving rank for files from search definitions * diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java index 2bae285301c..f2dd16577d1 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java @@ -11,7 +11,7 @@ import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.document.SDDocumentType; import com.yahoo.searchdefinition.document.SDField; import com.yahoo.searchdefinition.processing.Processing; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import com.yahoo.vespa.model.container.search.QueryProfiles; import org.junit.Test; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java index 723cd58a34a..93b366c8d08 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java @@ -5,7 +5,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import org.junit.Test; import java.io.File; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java index 8941b07101d..3a44c123f05 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java @@ -10,7 +10,7 @@ import com.yahoo.searchdefinition.SearchDefinitionTestCase; import com.yahoo.searchdefinition.document.SDDocumentType; import com.yahoo.searchdefinition.document.SDField; import com.yahoo.searchdefinition.processing.Processing; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import com.yahoo.vespa.model.container.search.QueryProfiles; import org.junit.Test; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java index d38bce04617..1c368ff6f10 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java @@ -6,13 +6,11 @@ import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.SearchDefinitionTestCase; import com.yahoo.searchdefinition.derived.DerivedConfiguration; -import com.yahoo.searchdefinition.derived.Deriver; import com.yahoo.searchdefinition.document.SDDocumentType; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; -import org.junit.Ignore; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import org.junit.Test; -import java.io.File; + import java.io.IOException; import static org.junit.Assert.assertEquals; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java index cff9abb08ed..7f590d4b230 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.processing; +import com.google.common.collect.ImmutableList; import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.config.model.test.MockApplicationPackage; import com.yahoo.path.Path; @@ -10,7 +11,11 @@ import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ModelImporter; +import ai.vespa.rankingexpression.importer.onnx.OnnxImporter; +import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter; +import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter; import java.util.HashMap; import java.util.List; @@ -26,6 +31,9 @@ import static org.junit.Assert.assertEquals; */ class RankProfileSearchFixture { + private final ImmutableList<ModelImporter> importers = ImmutableList.of(new TensorFlowImporter(), + new OnnxImporter(), + new XGBoostImporter()); private RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); private final QueryProfileRegistry queryProfileRegistry; private Search search; @@ -83,7 +91,8 @@ class RankProfileSearchFixture { public RankProfile compileRankProfile(String rankProfile, Path applicationDir) { RankProfile compiled = rankProfileRegistry.get(search, rankProfile) - .compile(queryProfileRegistry, new ImportedModels(applicationDir.toFile())); + .compile(queryProfileRegistry, + new ImportedModels(applicationDir.toFile(), importers)); compiledRankProfiles.put(rankProfile, compiled); return compiled; } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java index fd048737b43..78dcae53f7e 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java @@ -8,7 +8,7 @@ import com.yahoo.searchdefinition.derived.DerivedConfiguration; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import org.junit.Test; import java.io.IOException; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java index 6e3a227e2a9..7eee18c1059 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java @@ -17,7 +17,7 @@ import com.yahoo.searchdefinition.SearchDefinitionTestCase; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import org.junit.Test; import java.util.List; diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java index 2ae629562d0..4d8bdb7a6c6 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java @@ -1,11 +1,17 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.ml; +import com.google.common.collect.ImmutableList; import com.yahoo.config.model.ApplicationPackageTester; +import com.yahoo.config.model.deploy.DeployState; import com.yahoo.io.GrowableByteBuffer; import com.yahoo.io.IOUtils; import com.yahoo.path.Path; import com.yahoo.searchdefinition.RankingConstant; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ModelImporter; +import ai.vespa.rankingexpression.importer.onnx.OnnxImporter; +import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter; +import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.serialization.TypedBinaryFormat; import com.yahoo.vespa.model.VespaModel; @@ -26,6 +32,10 @@ import static org.junit.Assert.assertEquals; */ public class ImportedModelTester { + private final ImmutableList<ModelImporter> importers = ImmutableList.of(new TensorFlowImporter(), + new OnnxImporter(), + new XGBoostImporter()); + private final String modelName; private final Path applicationDir; @@ -36,7 +46,10 @@ public class ImportedModelTester { public VespaModel createVespaModel() { try { - return new VespaModel(ApplicationPackageTester.create(applicationDir.toString()).app()); + DeployState.Builder state = new DeployState.Builder(); + state.applicationPackage(ApplicationPackageTester.create(applicationDir.toString()).app()); + state.modelImporters(importers); + return new VespaModel(state.build()); } catch (SAXException | IOException e) { throw new RuntimeException(e); diff --git a/configserver/src/main/resources/configserver-app/services.xml b/configserver/src/main/resources/configserver-app/services.xml index 5f60be8c202..3dd6e0090c5 100644 --- a/configserver/src/main/resources/configserver-app/services.xml +++ b/configserver/src/main/resources/configserver-app/services.xml @@ -55,6 +55,7 @@ <preprocess:include file='config-models.xml' required='false' /> <preprocess:include file='node-repository.xml' required='false' /> <preprocess:include file='hosted-vespa/routing-status.xml' required='false' /> + <preprocess:include file='model-integration.xml' required='true' /> <!-- TODO Vespa 7: Remove scoreboard.xml, replaced by metrics-packets.xml --> <preprocess:include file='hosted-vespa/scoreboard.xml' required='false' /> diff --git a/container-dev/pom.xml b/container-dev/pom.xml index 1b0f36e3adb..7d61882c085 100644 --- a/container-dev/pom.xml +++ b/container-dev/pom.xml @@ -104,18 +104,6 @@ <groupId>org.apache.httpcomponents</groupId> <artifactId>httpclient</artifactId> </exclusion> - <exclusion> - <groupId>org.tensorflow</groupId> - <artifactId>proto</artifactId> - </exclusion> - <exclusion> - <groupId>org.tensorflow</groupId> - <artifactId>tensorflow</artifactId> - </exclusion> - <exclusion> - <groupId>com.google.protobuf</groupId> - <artifactId>protobuf-java</artifactId> - </exclusion> </exclusions> </dependency> <dependency> @@ -178,18 +166,6 @@ <groupId>xerces</groupId> <artifactId>xercesImpl</artifactId> </exclusion> - <exclusion> - <groupId>org.tensorflow</groupId> - <artifactId>proto</artifactId> - </exclusion> - <exclusion> - <groupId>org.tensorflow</groupId> - <artifactId>tensorflow</artifactId> - </exclusion> - <exclusion> - <groupId>com.google.protobuf</groupId> - <artifactId>protobuf-java</artifactId> - </exclusion> </exclusions> </dependency> <dependency> diff --git a/container-disc/pom.xml b/container-disc/pom.xml index 62985192ad3..be4ac23a938 100644 --- a/container-disc/pom.xml +++ b/container-disc/pom.xml @@ -174,6 +174,7 @@ jdisc-security-filters-jar-with-dependencies.jar, jdisc_http_service-jar-with-dependencies.jar, model-evaluation-jar-with-dependencies.jar, + model-integration-jar-with-dependencies.jar, vespaclient-container-plugin-jar-with-dependencies.jar, vespa-athenz-jar-with-dependencies.jar, security-utils-jar-with-dependencies.jar, diff --git a/container/pom.xml b/container/pom.xml index 32a7947d6d5..529f01e0a40 100644 --- a/container/pom.xml +++ b/container/pom.xml @@ -53,18 +53,6 @@ <groupId>org.apache.commons</groupId> <artifactId>commons-lang3</artifactId> </exclusion> - <exclusion> - <groupId>org.tensorflow</groupId> - <artifactId>proto</artifactId> - </exclusion> - <exclusion> - <groupId>org.tensorflow</groupId> - <artifactId>tensorflow</artifactId> - </exclusion> - <exclusion> - <groupId>com.google.protobuf</groupId> - <artifactId>protobuf-java</artifactId> - </exclusion> </exclusions> </dependency> </dependencies> diff --git a/document/src/main/java/com/yahoo/document/Document.java b/document/src/main/java/com/yahoo/document/Document.java index 85049baf8b0..222ebe29c6d 100644 --- a/document/src/main/java/com/yahoo/document/Document.java +++ b/document/src/main/java/com/yahoo/document/Document.java @@ -286,6 +286,7 @@ public class Document extends StructuredFieldValue { /** * Get JSON representation of the document root and its children contained in a JSON object + * * @return JSON representation of document */ public String toJson() { diff --git a/model-integration/pom.xml b/model-integration/pom.xml new file mode 100644 index 00000000000..28a00dcbdbc --- /dev/null +++ b/model-integration/pom.xml @@ -0,0 +1,100 @@ +<?xml version="1.0"?> +<!-- Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> +<project xmlns="http://maven.apache.org/POM/4.0.0" + xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 + http://maven.apache.org/xsd/maven-4.0.0.xsd"> + <modelVersion>4.0.0</modelVersion> + <parent> + <groupId>com.yahoo.vespa</groupId> + <artifactId>parent</artifactId> + <version>6-SNAPSHOT</version> + <relativePath>../parent/pom.xml</relativePath> + </parent> + <artifactId>model-integration</artifactId> + <version>6-SNAPSHOT</version> + <packaging>container-plugin</packaging> + <dependencies> + <dependency> + <groupId>junit</groupId> + <artifactId>junit</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> + <artifactId>component</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> + <artifactId>vespajlib</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> + <artifactId>searchlib</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>com.google.guava</groupId> + <artifactId>guava</artifactId> + <scope>provided</scope> + </dependency> + + <dependency> + <groupId>com.google.protobuf</groupId> + <artifactId>protobuf-java</artifactId> + </dependency> + <dependency> + <groupId>org.tensorflow</groupId> + <artifactId>proto</artifactId> + </dependency> + <dependency> + <groupId>org.tensorflow</groupId> + <artifactId>tensorflow</artifactId> + </dependency> + </dependencies> + <build> + <plugins> + <plugin> + <groupId>com.yahoo.vespa</groupId> + <artifactId>bundle-plugin</artifactId> + <extensions>true</extensions> + </plugin> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-compiler-plugin</artifactId> + <configuration> + <compilerArgs> + <arg>-Xlint:rawtypes</arg> + <arg>-Xlint:unchecked</arg> + <arg>-Werror</arg> + </compilerArgs> + </configuration> + </plugin> + <plugin> + <groupId>com.github.os72</groupId> + <artifactId>protoc-jar-maven-plugin</artifactId> + <version>3.5.1.1</version> + <executions> + <execution> + <phase>generate-sources</phase> + <goals> + <goal>run</goal> + </goals> + <configuration> + <addSources>main</addSources> + <outputDirectory>${project.build.directory}/generated-sources/protobuf</outputDirectory> + <inputDirectories> + <include>src/main/protobuf</include> + </inputDirectories> + </configuration> + </execution> + </executions> + </plugin> + </plugins> + </build> +</project> diff --git a/model-integration/src/main/config/model-integration.xml b/model-integration/src/main/config/model-integration.xml new file mode 100644 index 00000000000..da45ce23575 --- /dev/null +++ b/model-integration/src/main/config/model-integration.xml @@ -0,0 +1,10 @@ +<!-- Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> +<!-- Component which can import some ml model. + This is included into the config server services.xml to enable it to translate + model pseudofeatures in ranking expressions during config mddel building. + It is provided as separate bundles instead of being included in the config model + because some of these (TensorFlow) includes + JNI code, and so can only exist in one instance in the server. --> +<component id="ai.vespa.rankingexpression.importer.onnx.OnnxImporter" bundle="model-integration" /> +<component id="ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter" bundle="model-integration" /> +<component id="ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter" bundle="model-integration" /> diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java index 3fe92440cae..81d0753ea4b 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java @@ -1,6 +1,6 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx; +package ai.vespa.rankingexpression.importer.onnx; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph; @@ -28,9 +28,9 @@ import java.util.stream.Collectors; * * @author lesters */ -public class GraphImporter { +class GraphImporter { - public static IntermediateOperation mapOperation(Onnx.NodeProto node, + private static IntermediateOperation mapOperation(Onnx.NodeProto node, List<IntermediateOperation> inputs, IntermediateGraph graph) { String nodeName = node.getName(); @@ -79,7 +79,7 @@ public class GraphImporter { return op; } - public static IntermediateGraph importGraph(String modelName, Onnx.ModelProto model) { + static IntermediateGraph importGraph(String modelName, Onnx.ModelProto model) { Onnx.GraphProto onnxGraph = model.getGraph(); IntermediateGraph intermediateGraph = new IntermediateGraph(modelName); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java index e6bb5f40b3f..0418581d7b2 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java @@ -1,9 +1,10 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package ai.vespa.rankingexpression.importer.onnx; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel; import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx.GraphImporter; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ModelImporter; import onnx.Onnx; import java.io.File; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java index 18856d4a25f..6dd33c79852 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java @@ -1,6 +1,6 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx; +package ai.vespa.rankingexpression.importer.onnx; import com.google.protobuf.ByteString; import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; @@ -17,9 +17,9 @@ import java.nio.FloatBuffer; * * @author lesters */ -public class TensorConverter { +class TensorConverter { - public static Tensor toVespaTensor(Onnx.TensorProto tensorProto, OrderedTensorType type) { + static Tensor toVespaTensor(Onnx.TensorProto tensorProto, OrderedTensorType type) { Values values = readValuesOf(tensorProto); IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type.type()); for (int i = 0; i < values.size(); i++) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java index 715c55d8323..43ceaa747b7 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java @@ -1,6 +1,6 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx; +package ai.vespa.rankingexpression.importer.onnx; import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.tensor.TensorType; @@ -11,9 +11,9 @@ import onnx.Onnx; * * @author lesters */ -public class TypeConverter { +class TypeConverter { - public static void verifyType(Onnx.TypeProto typeProto, OrderedTensorType type) { + static void verifyType(Onnx.TypeProto typeProto, OrderedTensorType type) { Onnx.TensorShapeProto shape = typeProto.getTensorType().getShape(); if (shape != null) { if (shape.getDimCount() != type.rank()) { @@ -30,11 +30,11 @@ public class TypeConverter { } } - public static OrderedTensorType fromOnnxType(Onnx.TypeProto type) { + static OrderedTensorType fromOnnxType(Onnx.TypeProto type) { return fromOnnxType(type, "d"); // standard naming convention: d0, d1, ... } - public static OrderedTensorType fromOnnxType(Onnx.TypeProto type, String dimensionPrefix) { + private static OrderedTensorType fromOnnxType(Onnx.TypeProto type, String dimensionPrefix) { Onnx.TensorShapeProto shape = type.getTensorType().getShape(); OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); for (int i = 0; i < shape.getDimCount(); ++ i) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/package-info.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/package-info.java new file mode 100644 index 00000000000..9599cf8627c --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/package-info.java @@ -0,0 +1,5 @@ +@ExportPackage + +package ai.vespa.rankingexpression.importer.onnx; + +import com.yahoo.osgi.annotation.ExportPackage;
\ No newline at end of file diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/AttributeConverter.java index 89b75e8e3e2..73310c78cab 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/AttributeConverter.java @@ -1,5 +1,5 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow; +package ai.vespa.rankingexpression.importer.tensorflow; import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; @@ -20,15 +20,15 @@ import java.util.stream.Collectors; * * @author lesters */ -public class AttributeConverter implements IntermediateOperation.AttributeMap { +class AttributeConverter implements IntermediateOperation.AttributeMap { private final Map<String, AttrValue> attributeMap; - public AttributeConverter(NodeDef node) { + private AttributeConverter(NodeDef node) { attributeMap = node.getAttrMap(); } - public static AttributeConverter convert(NodeDef node) { + static AttributeConverter convert(NodeDef node) { return new AttributeConverter(node); } @@ -83,4 +83,5 @@ public class AttributeConverter implements IntermediateOperation.AttributeMap { } return Optional.empty(); } + } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java index e1b292f9e61..c012bc3c54f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java @@ -1,6 +1,6 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow; +package ai.vespa.rankingexpression.importer.tensorflow; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph; @@ -43,9 +43,9 @@ import java.util.stream.Collectors; * * @author lesters */ -public class GraphImporter { +class GraphImporter { - public static IntermediateOperation mapOperation(NodeDef node, + private static IntermediateOperation mapOperation(NodeDef node, List<IntermediateOperation> inputs, IntermediateGraph graph) { String nodeName = node.getName(); @@ -112,7 +112,7 @@ public class GraphImporter { return op; } - public static IntermediateGraph importGraph(String modelName, SavedModelBundle bundle) throws IOException { + static IntermediateGraph importGraph(String modelName, SavedModelBundle bundle) throws IOException { MetaGraphDef tfGraph = MetaGraphDef.parseFrom(bundle.metaGraphDef()); IntermediateGraph intermediateGraph = new IntermediateGraph(modelName); @@ -209,7 +209,7 @@ public class GraphImporter { throw new IllegalArgumentException("Could not find node '" + name + "'"); } - public static org.tensorflow.Tensor<?> readVariable(String name, SavedModelBundle bundle) { + static org.tensorflow.Tensor<?> readVariable(String name, SavedModelBundle bundle) { Session.Runner fetched = bundle.session().runner().fetch(name); List<org.tensorflow.Tensor<?>> importedTensors = fetched.run(); if (importedTensors.size() != 1) diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java index d2d0acfc964..4e67286ef09 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TensorConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java @@ -1,5 +1,5 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow; +package ai.vespa.rankingexpression.importer.tensorflow; import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.tensor.IndexedTensor; @@ -26,7 +26,7 @@ public class TensorConverter { return toVespaTensor(tfTensor, "d"); } - public static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) { + private static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) { TensorType type = toVespaTensorType(tfTensor.shape(), dimensionPrefix); Values values = readValuesOf(tfTensor); IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type); @@ -35,7 +35,7 @@ public class TensorConverter { return builder.build(); } - public static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, OrderedTensorType type) { + static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, OrderedTensorType type) { Values values = readValuesOf(tfTensor); IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type.type()); for (int i = 0; i < values.size(); i++) { @@ -44,7 +44,7 @@ public class TensorConverter { return builder.build(); } - public static Tensor toVespaTensor(TensorProto tensorProto, TensorType type) { + static Tensor toVespaTensor(TensorProto tensorProto, TensorType type) { IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type); Values values = readValuesOf(tensorProto); for (int i = 0; i < values.size(); ++i) { @@ -71,7 +71,7 @@ public class TensorConverter { return size; } - public static Long dimensionSize(TensorType.Dimension dim) { + private static Long dimensionSize(TensorType.Dimension dim) { return dim.size().orElseThrow(() -> new IllegalArgumentException("Dimension has no size")); } @@ -182,8 +182,8 @@ public class TensorConverter { } private static abstract class ProtoValues extends Values { - protected final TensorProto tensorProto; - protected ProtoValues(TensorProto tensorProto) { this.tensorProto = tensorProto; } + final TensorProto tensorProto; + ProtoValues(TensorProto tensorProto) { this.tensorProto = tensorProto; } } private static class ProtoBoolValues extends ProtoValues { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java index 7c18e04bae7..f8a25e4d94c 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java @@ -1,8 +1,9 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package ai.vespa.rankingexpression.importer.tensorflow; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel; import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.GraphImporter; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ModelImporter; import org.tensorflow.SavedModelBundle; import java.io.File; @@ -47,7 +48,7 @@ public class TensorFlowImporter extends ModelImporter { } /** Imports a TensorFlow model */ - ImportedModel importModel(String modelName, String modelDir, SavedModelBundle model) { + public ImportedModel importModel(String modelName, String modelDir, SavedModelBundle model) { try { IntermediateGraph graph = GraphImporter.importGraph(modelName, model); return convertIntermediateGraphToModel(graph, modelDir); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java index 67ad1edc312..a5a506fdb6d 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java @@ -1,6 +1,6 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow; +package ai.vespa.rankingexpression.importer.tensorflow; import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.tensor.TensorType; @@ -15,9 +15,9 @@ import java.util.List; * * @author lesters */ -public class TypeConverter { +class TypeConverter { - public static void verifyType(NodeDef node, OrderedTensorType type) { + static void verifyType(NodeDef node, OrderedTensorType type) { TensorShapeProto shape = tensorFlowShape(node); if (shape != null) { if (shape.getDimCount() != type.rank()) { @@ -50,11 +50,11 @@ public class TypeConverter { return shapeList.get(0); // support multiple outputs? } - public static OrderedTensorType fromTensorFlowType(NodeDef node) { + static OrderedTensorType fromTensorFlowType(NodeDef node) { return fromTensorFlowType(node, "d"); // standard naming convention: d0, d1, ... } - public static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) { + private static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) { OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); TensorShapeProto shape = tensorFlowShape(node); for (int i = 0; i < shape.getDimCount(); ++ i) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/VariableConverter.java index 25bac27f315..b777ee07e58 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/VariableConverter.java @@ -1,9 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package ai.vespa.rankingexpression.importer.tensorflow; import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.TensorConverter; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.GraphImporter; import com.yahoo.tensor.serialization.JsonFormat; import com.yahoo.yolean.Exceptions; import org.tensorflow.SavedModelBundle; @@ -16,7 +14,7 @@ import java.nio.charset.StandardCharsets; * * @author bratseth */ -public class VariableConverter { +class VariableConverter { /** * Reads the tensor with the given TensorFlow name at the given model location, @@ -24,7 +22,7 @@ public class VariableConverter { * Note that order of dimensions in the tensor type does matter as the TensorFlow tensor * tensor dimensions are implicitly ordered. */ - public static byte[] importVariable(String modelDir, String tensorFlowVariableName, String orderedTypeSpec) { + static byte[] importVariable(String modelDir, String tensorFlowVariableName, String orderedTypeSpec) { try (SavedModelBundle bundle = SavedModelBundle.load(modelDir, "serve")) { return JsonFormat.encode(TensorConverter.toVespaTensor(GraphImporter.readVariable(tensorFlowVariableName, bundle), diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/package-info.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/package-info.java new file mode 100644 index 00000000000..0840e584d25 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/package-info.java @@ -0,0 +1,4 @@ +@ExportPackage +package ai.vespa.rankingexpression.importer.tensorflow; + +import com.yahoo.osgi.annotation.ExportPackage;
\ No newline at end of file diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/XGBoostImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java index 725f319a839..e87dd265d50 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/XGBoostImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java @@ -1,8 +1,9 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package ai.vespa.rankingexpression.importer.xgboost; import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.xgboost.XGBoostParser; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ModelImporter; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import java.io.File; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostParser.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostParser.java index fef8bfec81d..2b215f816f5 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostParser.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostParser.java @@ -1,5 +1,5 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.xgboost; +package ai.vespa.rankingexpression.importer.xgboost; import java.io.File; import java.io.IOException; @@ -13,7 +13,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; /** * @author grace-lam */ -public class XGBoostParser { +class XGBoostParser { private List<XGBoostTree> xgboostTrees; @@ -24,7 +24,7 @@ public class XGBoostParser { * @throws JsonProcessingException Fails JSON parsing. * @throws IOException Fails file reading. */ - public XGBoostParser(String filePath) throws JsonProcessingException, IOException { + XGBoostParser(String filePath) throws JsonProcessingException, IOException { this.xgboostTrees = new ArrayList<>(); ObjectMapper mapper = new ObjectMapper(); JsonNode forestNode = mapper.readTree(new File(filePath)); @@ -38,7 +38,7 @@ public class XGBoostParser { * * @return Vespa ranking expressions. */ - public String toRankingExpression() { + String toRankingExpression() { StringBuilder ret = new StringBuilder(); for (int i = 0; i < xgboostTrees.size(); i++) { ret.append(treeToRankExp(xgboostTrees.get(i))); @@ -55,7 +55,7 @@ public class XGBoostParser { * @param node XGBoost tree node to convert. * @return Vespa ranking expression for input node. */ - public String treeToRankExp(XGBoostTree node) { + private String treeToRankExp(XGBoostTree node) { if (node.isLeaf()) { return Double.toString(node.getLeaf()); } else { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostTree.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostTree.java index 6bbc9abe8ae..e32e0f1eab5 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostTree.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostTree.java @@ -1,5 +1,5 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.xgboost; +package ai.vespa.rankingexpression.importer.xgboost; import java.util.List; diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/package-info.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/package-info.java new file mode 100644 index 00000000000..d310de9041b --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/package-info.java @@ -0,0 +1,4 @@ +@ExportPackage +package ai.vespa.rankingexpression.importer.xgboost; + +import com.yahoo.osgi.annotation.ExportPackage;
\ No newline at end of file diff --git a/searchlib/src/main/protobuf/onnx.proto b/model-integration/src/main/protobuf/onnx.proto index dc6542867e0..dc6542867e0 100644 --- a/searchlib/src/main/protobuf/onnx.proto +++ b/model-integration/src/main/protobuf/onnx.proto diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java index 6d0ee0a906d..a71d7a42551 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java @@ -1,11 +1,13 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package ai.vespa.rankingexpression.importer.onnx; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel; +import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.junit.Test; @@ -21,7 +23,7 @@ public class OnnxMnistSoftmaxImportTestCase { @Test public void testMnistSoftmaxImport() { - ImportedModel model = new OnnxImporter().importModel("test", "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx"); + ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/mnist_softmax/mnist_softmax.onnx"); // Check constants assertEquals(2, model.largeConstants().size()); @@ -55,8 +57,8 @@ public class OnnxMnistSoftmaxImportTestCase { @Test public void testComparisonBetweenOnnxAndTensorflow() { - String tfModelPath = "src/test/files/integration/tensorflow/mnist_softmax/saved"; - String onnxModelPath = "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx"; + String tfModelPath = "src/test/models/tensorflow/mnist_softmax/saved"; + String onnxModelPath = "src/test/models/onnx/mnist_softmax/mnist_softmax.onnx"; Tensor argument = placeholderArgument(); Tensor tensorFlowResult = evaluateTensorFlowModel(tfModelPath, argument, "Placeholder", "add"); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BatchNormImportTestCase.java index e325c3d11b4..acd649d985b 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BatchNormImportTestCase.java @@ -1,8 +1,8 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package ai.vespa.rankingexpression.importer.tensorflow; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; -import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel; import org.junit.Test; import static org.junit.Assert.assertEquals; @@ -16,7 +16,7 @@ public class BatchNormImportTestCase { @Test public void testBatchNormImport() { TestableTensorFlowModel model = new TestableTensorFlowModel("test", - "src/test/files/integration/tensorflow/batch_norm/saved"); + "src/test/models/tensorflow/batch_norm/saved"); ImportedModel.Signature signature = model.get().signature("serving_default"); assertEquals("Has skipped outputs", diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BlogEvaluationBenchmark.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BlogEvaluationBenchmark.java index 5b0cca6a940..a878b284b2c 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BlogEvaluationBenchmark.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BlogEvaluationBenchmark.java @@ -1,5 +1,5 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package ai.vespa.rankingexpression.importer.tensorflow; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Context; @@ -7,7 +7,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer; import com.yahoo.searchlib.rankingexpression.evaluation.OptimizationReport; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.TensorConverter; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -17,8 +17,6 @@ import org.tensorflow.Session; import java.nio.FloatBuffer; import java.util.List; -import static com.yahoo.searchlib.rankingexpression.integration.ml.TestableTensorFlowModel.contextFrom; - /** * Microbenchmark of imported ML model evaluation. * @@ -26,13 +24,13 @@ import static com.yahoo.searchlib.rankingexpression.integration.ml.TestableTenso */ public class BlogEvaluationBenchmark { - static final String modelDir = "src/test/files/integration/tensorflow/blog/saved"; + static final String modelDir = "src/test/models/tensorflow/blog/saved"; public static void main(String[] args) throws ParseException { SavedModelBundle tensorFlowModel = SavedModelBundle.load(modelDir, "serve"); ImportedModel model = new TensorFlowImporter().importModel("blog", modelDir, tensorFlowModel); - Context context = contextFrom(model); + Context context = TestableTensorFlowModel.contextFrom(model); Tensor u = generateInputTensor(); Tensor d = generateInputTensor(); context.put("input_u", new TensorValue(u)); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java index 8ca5a9a7888..6e58761e5ce 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java @@ -1,9 +1,10 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package ai.vespa.rankingexpression.importer.tensorflow; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; -import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel; import com.yahoo.tensor.TensorType; +import org.junit.Assert; import org.junit.Test; import static org.junit.Assert.assertEquals; @@ -17,18 +18,18 @@ public class DropoutImportTestCase { @Test public void testDropoutImport() { - TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/dropout/saved"); + TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/models/tensorflow/dropout/saved"); // Check required functions - assertEquals(1, model.get().inputs().size()); + Assert.assertEquals(1, model.get().inputs().size()); assertTrue(model.get().inputs().containsKey("X")); - assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), - model.get().inputs().get("X")); + Assert.assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), + model.get().inputs().get("X")); ImportedModel.Signature signature = model.get().signature("serving_default"); - assertEquals("Has skipped outputs", - 0, model.get().signature("serving_default").skippedOutputs().size()); + Assert.assertEquals("Has skipped outputs", + 0, model.get().signature("serving_default").skippedOutputs().size()); ExpressionFunction function = signature.outputExpression("y"); assertNotNull(function); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/MnistImportTestCase.java index 3d8d5d5a570..b338f46fb4d 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/MnistImportTestCase.java @@ -1,8 +1,9 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package ai.vespa.rankingexpression.importer.tensorflow; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; -import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel; +import org.junit.Assert; import org.junit.Test; import static org.junit.Assert.assertEquals; @@ -15,11 +16,11 @@ public class MnistImportTestCase { @Test public void testMnistImport() { - TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/mnist/saved"); + TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/models/tensorflow/mnist/saved"); ImportedModel.Signature signature = model.get().signature("serving_default"); - assertEquals("Has skipped outputs", - 0, model.get().signature("serving_default").skippedOutputs().size()); + Assert.assertEquals("Has skipped outputs", + 0, model.get().signature("serving_default").skippedOutputs().size()); ExpressionFunction output = signature.outputExpression("y"); assertNotNull(output); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java index feba40601e3..7e8cbef8ada 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java @@ -1,10 +1,11 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package ai.vespa.rankingexpression.importer.tensorflow; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; -import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import org.junit.Assert; import org.junit.Test; import static org.junit.Assert.assertEquals; @@ -18,10 +19,10 @@ public class TensorFlowMnistSoftmaxImportTestCase { @Test public void testMnistSoftmaxImport() { - TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/mnist_softmax/saved"); + TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/models/tensorflow/mnist_softmax/saved"); // Check constants - assertEquals(2, model.get().largeConstants().size()); + Assert.assertEquals(2, model.get().largeConstants().size()); Tensor constant0 = model.get().largeConstants().get("test_Variable_read"); assertNotNull(constant0); @@ -36,16 +37,16 @@ public class TensorFlowMnistSoftmaxImportTestCase { assertEquals(10, constant1.size()); // Check (provided) functions - assertEquals(0, model.get().functions().size()); + Assert.assertEquals(0, model.get().functions().size()); // Check required functions - assertEquals(1, model.get().inputs().size()); + Assert.assertEquals(1, model.get().inputs().size()); assertTrue(model.get().inputs().containsKey("Placeholder")); - assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), - model.get().inputs().get("Placeholder")); + Assert.assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), + model.get().inputs().get("Placeholder")); // Check signatures - assertEquals(1, model.get().signatures().size()); + Assert.assertEquals(1, model.get().signatures().size()); ImportedModel.Signature signature = model.get().signatures().get("serving_default"); assertNotNull(signature); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java index fbe7c5fac63..dbed537885e 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java @@ -1,5 +1,5 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package ai.vespa.rankingexpression.importer.tensorflow; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Context; @@ -7,7 +7,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.TensorConverter; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverterTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/VariableConverterTestCase.java index aabbdf33c9d..c9fffe143b4 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverterTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/VariableConverterTestCase.java @@ -1,5 +1,5 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package ai.vespa.rankingexpression.importer.tensorflow; import org.junit.Test; @@ -11,7 +11,7 @@ public class VariableConverterTestCase { @Test public void testConversion() { - byte[] converted = VariableConverter.importVariable("src/test/files/integration/tensorflow/mnist_softmax/saved", + byte[] converted = VariableConverter.importVariable("src/test/models/tensorflow/mnist_softmax/saved", "Variable_1", "tensor(d0[10],d1[1])"); assertEquals("{\"cells\":[{\"address\":{\"d0\":\"0\",\"d1\":\"0\"},\"value\":-0.3546536862850189},{\"address\":{\"d0\":\"1\",\"d1\":\"0\"},\"value\":0.3759574592113495},{\"address\":{\"d0\":\"2\",\"d1\":\"0\"},\"value\":0.06054411828517914},{\"address\":{\"d0\":\"3\",\"d1\":\"0\"},\"value\":-0.251544713973999},{\"address\":{\"d0\":\"4\",\"d1\":\"0\"},\"value\":0.017951013520359993},{\"address\":{\"d0\":\"5\",\"d1\":\"0\"},\"value\":1.2899067401885986},{\"address\":{\"d0\":\"6\",\"d1\":\"0\"},\"value\":-0.10389615595340729},{\"address\":{\"d0\":\"7\",\"d1\":\"0\"},\"value\":0.6367976665496826},{\"address\":{\"d0\":\"8\",\"d1\":\"0\"},\"value\":-1.4136744737625122},{\"address\":{\"d0\":\"9\",\"d1\":\"0\"},\"value\":-0.2573896050453186}]}", diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java new file mode 100644 index 00000000000..48c7f5bee19 --- /dev/null +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java @@ -0,0 +1,28 @@ +package ai.vespa.rankingexpression.importer.xgboost; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +/** + * @author bratseth + */ +public class XGBoostImportTestCase { + + @Test + public void testXGBoost() { + ImportedModel model = new XGBoostImporter().importModel("test", "src/test/models/xgboost/xgboost.2.2.json"); + assertTrue("All inputs are scalar", model.inputs().isEmpty()); + assertEquals(1, model.expressions().size()); + System.out.println(model.expressions().keySet()); + RankingExpression expression = model.expressions().get("test"); + assertNotNull(expression); + assertEquals("if (f29 < -0.1234567, if (f56 < -0.242398, 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + if (f60 < -0.482947, if (f29 < -4.2387498, 0.784718, -0.96853), -6.23624)", + expression.getRoot().toString()); + } + +} diff --git a/searchlib/src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx b/model-integration/src/test/models/onnx/mnist_softmax/mnist_softmax.onnx Binary files differindex a86019bf53a..a86019bf53a 100644 --- a/searchlib/src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx +++ b/model-integration/src/test/models/onnx/mnist_softmax/mnist_softmax.onnx diff --git a/searchlib/src/test/files/integration/tensorflow/batch_norm/batch_normalization_mnist.py b/model-integration/src/test/models/tensorflow/batch_norm/batch_normalization_mnist.py index bc6ea13ebc1..bc6ea13ebc1 100644 --- a/searchlib/src/test/files/integration/tensorflow/batch_norm/batch_normalization_mnist.py +++ b/model-integration/src/test/models/tensorflow/batch_norm/batch_normalization_mnist.py diff --git a/searchlib/src/test/files/integration/tensorflow/batch_norm/saved/saved_model.pbtxt b/model-integration/src/test/models/tensorflow/batch_norm/saved/saved_model.pbtxt index f3ce68a1cbd..f3ce68a1cbd 100644 --- a/searchlib/src/test/files/integration/tensorflow/batch_norm/saved/saved_model.pbtxt +++ b/model-integration/src/test/models/tensorflow/batch_norm/saved/saved_model.pbtxt diff --git a/searchlib/src/test/files/integration/tensorflow/batch_norm/saved/variables/variables.data-00000-of-00001 b/model-integration/src/test/models/tensorflow/batch_norm/saved/variables/variables.data-00000-of-00001 Binary files differindex 875e8361e10..875e8361e10 100644 --- a/searchlib/src/test/files/integration/tensorflow/batch_norm/saved/variables/variables.data-00000-of-00001 +++ b/model-integration/src/test/models/tensorflow/batch_norm/saved/variables/variables.data-00000-of-00001 diff --git a/searchlib/src/test/files/integration/tensorflow/batch_norm/saved/variables/variables.index b/model-integration/src/test/models/tensorflow/batch_norm/saved/variables/variables.index Binary files differindex 46c7b258cf5..46c7b258cf5 100644 --- a/searchlib/src/test/files/integration/tensorflow/batch_norm/saved/variables/variables.index +++ b/model-integration/src/test/models/tensorflow/batch_norm/saved/variables/variables.index diff --git a/searchlib/src/test/files/integration/tensorflow/blog/saved/saved_model.pbtxt b/model-integration/src/test/models/tensorflow/blog/saved/saved_model.pbtxt index a669e69b709..a669e69b709 100644 --- a/searchlib/src/test/files/integration/tensorflow/blog/saved/saved_model.pbtxt +++ b/model-integration/src/test/models/tensorflow/blog/saved/saved_model.pbtxt diff --git a/searchlib/src/test/files/integration/tensorflow/blog/saved/variables/variables.data-00000-of-00001 b/model-integration/src/test/models/tensorflow/blog/saved/variables/variables.data-00000-of-00001 Binary files differindex 1efd102aef9..1efd102aef9 100644 --- a/searchlib/src/test/files/integration/tensorflow/blog/saved/variables/variables.data-00000-of-00001 +++ b/model-integration/src/test/models/tensorflow/blog/saved/variables/variables.data-00000-of-00001 diff --git a/searchlib/src/test/files/integration/tensorflow/blog/saved/variables/variables.index b/model-integration/src/test/models/tensorflow/blog/saved/variables/variables.index Binary files differindex 56c60dbe529..56c60dbe529 100644 --- a/searchlib/src/test/files/integration/tensorflow/blog/saved/variables/variables.index +++ b/model-integration/src/test/models/tensorflow/blog/saved/variables/variables.index diff --git a/searchlib/src/test/files/integration/tensorflow/dropout/dropout.py b/model-integration/src/test/models/tensorflow/dropout/dropout.py index 42c15cd2812..42c15cd2812 100644 --- a/searchlib/src/test/files/integration/tensorflow/dropout/dropout.py +++ b/model-integration/src/test/models/tensorflow/dropout/dropout.py diff --git a/searchlib/src/test/files/integration/tensorflow/dropout/saved/saved_model.pbtxt b/model-integration/src/test/models/tensorflow/dropout/saved/saved_model.pbtxt index ad431f0460d..ad431f0460d 100644 --- a/searchlib/src/test/files/integration/tensorflow/dropout/saved/saved_model.pbtxt +++ b/model-integration/src/test/models/tensorflow/dropout/saved/saved_model.pbtxt diff --git a/searchlib/src/test/files/integration/tensorflow/dropout/saved/variables/variables.data-00000-of-00001 b/model-integration/src/test/models/tensorflow/dropout/saved/variables/variables.data-00000-of-00001 Binary files differindex 000c9b3a7b5..000c9b3a7b5 100644 --- a/searchlib/src/test/files/integration/tensorflow/dropout/saved/variables/variables.data-00000-of-00001 +++ b/model-integration/src/test/models/tensorflow/dropout/saved/variables/variables.data-00000-of-00001 diff --git a/searchlib/src/test/files/integration/tensorflow/dropout/saved/variables/variables.index b/model-integration/src/test/models/tensorflow/dropout/saved/variables/variables.index Binary files differindex 9492ef4bde2..9492ef4bde2 100644 --- a/searchlib/src/test/files/integration/tensorflow/dropout/saved/variables/variables.index +++ b/model-integration/src/test/models/tensorflow/dropout/saved/variables/variables.index diff --git a/searchlib/src/test/files/integration/tensorflow/mnist/saved/saved_model.pbtxt b/model-integration/src/test/models/tensorflow/mnist/saved/saved_model.pbtxt index eb926836576..eb926836576 100644 --- a/searchlib/src/test/files/integration/tensorflow/mnist/saved/saved_model.pbtxt +++ b/model-integration/src/test/models/tensorflow/mnist/saved/saved_model.pbtxt diff --git a/searchlib/src/test/files/integration/tensorflow/mnist/saved/variables/variables.data-00000-of-00001 b/model-integration/src/test/models/tensorflow/mnist/saved/variables/variables.data-00000-of-00001 Binary files differindex a7ca01888c7..a7ca01888c7 100644 --- a/searchlib/src/test/files/integration/tensorflow/mnist/saved/variables/variables.data-00000-of-00001 +++ b/model-integration/src/test/models/tensorflow/mnist/saved/variables/variables.data-00000-of-00001 diff --git a/searchlib/src/test/files/integration/tensorflow/mnist/saved/variables/variables.index b/model-integration/src/test/models/tensorflow/mnist/saved/variables/variables.index Binary files differindex 7989c109a3a..7989c109a3a 100644 --- a/searchlib/src/test/files/integration/tensorflow/mnist/saved/variables/variables.index +++ b/model-integration/src/test/models/tensorflow/mnist/saved/variables/variables.index diff --git a/searchlib/src/test/files/integration/tensorflow/mnist/simple_mnist.py b/model-integration/src/test/models/tensorflow/mnist/simple_mnist.py index 86a17e81f8f..86a17e81f8f 100644 --- a/searchlib/src/test/files/integration/tensorflow/mnist/simple_mnist.py +++ b/model-integration/src/test/models/tensorflow/mnist/simple_mnist.py diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py b/model-integration/src/test/models/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py index 07a9fa4a213..07a9fa4a213 100644 --- a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py +++ b/model-integration/src/test/models/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt b/model-integration/src/test/models/tensorflow/mnist_softmax/saved/saved_model.pbtxt index 8100dfd594d..8100dfd594d 100644 --- a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt +++ b/model-integration/src/test/models/tensorflow/mnist_softmax/saved/saved_model.pbtxt diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001 b/model-integration/src/test/models/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001 Binary files differindex 8474aa0a04c..8474aa0a04c 100644 --- a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001 +++ b/model-integration/src/test/models/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001 diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.index b/model-integration/src/test/models/tensorflow/mnist_softmax/saved/variables/variables.index Binary files differindex cfcdac20409..cfcdac20409 100644 --- a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.index +++ b/model-integration/src/test/models/tensorflow/mnist_softmax/saved/variables/variables.index diff --git a/model-integration/src/test/models/xgboost/xgboost.2.2.json b/model-integration/src/test/models/xgboost/xgboost.2.2.json new file mode 100644 index 00000000000..f8949b47e52 --- /dev/null +++ b/model-integration/src/test/models/xgboost/xgboost.2.2.json @@ -0,0 +1,19 @@ +[ + { "nodeid": 0, "depth": 0, "split": "f29", "split_condition": -0.1234567, "yes": 1, "no": 2, "missing": 1, "children": [ + { "nodeid": 1, "depth": 1, "split": "f56", "split_condition": -0.242398, "yes": 3, "no": 4, "missing": 3, "children": [ + { "nodeid": 3, "leaf": 1.71218 }, + { "nodeid": 4, "leaf": -1.70044 } + ]}, + { "nodeid": 2, "depth": 1, "split": "f109", "split_condition": 0.8723473, "yes": 5, "no": 6, "missing": 5, "children": [ + { "nodeid": 5, "leaf": -1.94071 }, + { "nodeid": 6, "leaf": 1.85965 } + ]} + ]}, + { "nodeid": 0, "depth": 0, "split": "f60", "split_condition": -0.482947, "yes": 1, "no": 2, "missing": 1, "children": [ + { "nodeid": 1, "depth": 1, "split": "f29", "split_condition": -4.2387498, "yes": 3, "no": 4, "missing": 3, "children": [ + { "nodeid": 3, "leaf": 0.784718 }, + { "nodeid": 4, "leaf": -0.96853 } + ]}, + { "nodeid": 2, "leaf": -6.23624 } + ]} +]
\ No newline at end of file @@ -94,6 +94,7 @@ <module>messagebus</module> <module>metrics</module> <module>model-evaluation</module> + <module>model-integration</module> <module>node-repository</module> <module>node-admin</module> <module>node-maintainer</module> diff --git a/searchlib/pom.xml b/searchlib/pom.xml index f1be4e96269..e0ce822e593 100644 --- a/searchlib/pom.xml +++ b/searchlib/pom.xml @@ -35,18 +35,6 @@ <version>${project.version}</version> </dependency> <dependency> - <groupId>com.google.protobuf</groupId> - <artifactId>protobuf-java</artifactId> - </dependency> - <dependency> - <groupId>org.tensorflow</groupId> - <artifactId>proto</artifactId> - </dependency> - <dependency> - <groupId>org.tensorflow</groupId> - <artifactId>tensorflow</artifactId> - </dependency> - <dependency> <groupId>com.fasterxml.jackson.core</groupId> <artifactId>jackson-core</artifactId> <scope>provided</scope> @@ -117,26 +105,6 @@ </execution> </executions> </plugin> - <plugin> - <groupId>com.github.os72</groupId> - <artifactId>protoc-jar-maven-plugin</artifactId> - <version>3.5.1.1</version> - <executions> - <execution> - <phase>generate-sources</phase> - <goals> - <goal>run</goal> - </goals> - <configuration> - <addSources>main</addSources> - <outputDirectory>${project.build.directory}/generated-sources/protobuf</outputDirectory> - <inputDirectories> - <include>src/main/protobuf</include> - </inputDirectories> - </configuration> - </execution> - </executions> - </plugin> </plugins> </build> </project> diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/ImportedModel.java index 59ec66b7209..854a5202916 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/ImportedModel.java @@ -1,7 +1,6 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.yahoo.collections.Pair; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; @@ -10,13 +9,11 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import java.util.ArrayList; -import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.Optional; import java.util.regex.Pattern; @@ -64,8 +61,7 @@ public class ImportedModel { /** * Returns an immutable map of the small constants of this. - * These should have sizes up to a few kb at most, and correspond to constant - * values given in the TensorFlow or ONNX source. + * These should have sizes up to a few kb at most, and correspond to constant values given in the source model. */ public Map<String, Tensor> smallConstants() { return Collections.unmodifiableMap(smallConstants); } @@ -93,18 +89,18 @@ public class ImportedModel { public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); } /** Returns the given signature. If it does not already exist it is added to this. */ - Signature signature(String name) { + public Signature signature(String name) { return signatures.computeIfAbsent(name, Signature::new); } /** Convenience method for returning a default signature */ - Signature defaultSignature() { return signature(defaultSignatureName); } + public Signature defaultSignature() { return signature(defaultSignatureName); } - void input(String name, TensorType argumentType) { inputs.put(name, argumentType); } - void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); } - void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); } - void expression(String name, RankingExpression expression) { expressions.put(name, expression); } - void function(String name, RankingExpression expression) { functions.put(name, expression); } + public void input(String name, TensorType argumentType) { inputs.put(name, argumentType); } + public void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); } + public void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); } + public void expression(String name, RankingExpression expression) { expressions.put(name, expression); } + public void function(String name, RankingExpression expression) { functions.put(name, expression); } /** * Returns all the output expressions of this indexed by name. The names consist of one or two parts @@ -116,9 +112,9 @@ public class ImportedModel { List<Pair<String, ExpressionFunction>> expressions = new ArrayList<>(); for (Map.Entry<String, Signature> signatureEntry : signatures().entrySet()) { for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet()) - expressions.add(new Pair<>(signatureEntry.getKey() + "." + outputEntry.getKey(), + expressions.add(new Pair<>(signatureEntry.getKey() + "" + outputEntry.getKey(), signatureEntry.getValue().outputExpression(outputEntry.getKey()) - .withName(signatureEntry.getKey() + "." + outputEntry.getKey()))); + .withName(signatureEntry.getKey() + "" + outputEntry.getKey()))); if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs expressions.add(new Pair<>(signatureEntry.getKey(), new ExpressionFunction(signatureEntry.getKey(), diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/ImportedModels.java index 40d1ca8030a..55f1eef741c 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/ImportedModels.java @@ -1,7 +1,6 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.yahoo.path.Path; @@ -24,27 +23,26 @@ public class ImportedModels { /** All imported models, indexed by their names */ private final ImmutableMap<String, ImportedModel> importedModels; - private static final ImmutableList<ModelImporter> importers = - ImmutableList.of(new TensorFlowImporter(), new OnnxImporter(), new XGBoostImporter()); - /** Create a null imported models */ public ImportedModels() { importedModels = ImmutableMap.of(); } - public ImportedModels(File modelsDirectory) { + public ImportedModels(File modelsDirectory, Collection<ModelImporter> importers) { Map<String, ImportedModel> models = new HashMap<>(); // Find all subdirectories recursively which contains a model we can read - importRecursively(modelsDirectory, models); + importRecursively(modelsDirectory, models, importers); importedModels = ImmutableMap.copyOf(models); } - private static void importRecursively(File dir, Map<String, ImportedModel> models) { + private static void importRecursively(File dir, + Map<String, ImportedModel> models, + Collection<ModelImporter> importers) { if ( ! dir.isDirectory()) return; Arrays.stream(dir.listFiles()).sorted().forEach(child -> { - Optional<ModelImporter> importer = findImporterOf(child); + Optional<ModelImporter> importer = findImporterOf(child, importers); if (importer.isPresent()) { String name = toName(child); ImportedModel existing = models.get(name); @@ -54,12 +52,12 @@ public class ImportedModels { models.put(name, importer.get().importModel(name, child)); } else { - importRecursively(child, models); + importRecursively(child, models, importers); } }); } - private static Optional<ModelImporter> findImporterOf(File path) { + private static Optional<ModelImporter> findImporterOf(File path, Collection<ModelImporter> importers) { return importers.stream().filter(item -> item.canImport(path.toString())).findFirst(); } @@ -93,7 +91,7 @@ public class ImportedModels { } private static Path stripFileEnding(Path path) { - int dotIndex = path.last().lastIndexOf("."); + int dotIndex = path.last().lastIndexOf(""); if (dotIndex <= 0) return path; return path.withLast(path.last().substring(0, dotIndex)); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/ModelImporter.java index 481b7f9397a..1b6494e8ce8 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/ModelImporter.java @@ -1,11 +1,9 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Constant; import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation; import com.yahoo.searchlib.rankingexpression.parser.ParseException; @@ -47,7 +45,7 @@ public abstract class ModelImporter { * Takes an IntermediateGraph and converts it to a ImportedModel containing * the actual Vespa ranking expressions. */ - static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph, String modelSource) { + public static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph, String modelSource) { ImportedModel model = new ImportedModel(graph.name(), modelSource); graph.optimize(); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java index 7fc2aae87d1..ab6a80a193a 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java @@ -1,8 +1,8 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.VariableTensor; import com.yahoo.tensor.functions.Rename; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java index 1b8c62fe0e9..1f479cd2e0b 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java @@ -1,8 +1,8 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java index 0eff8e8bc08..aab50f422be 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java @@ -180,7 +180,7 @@ public abstract class IntermediateOperation { /** * An interface mapping operation attributes to Vespa Values. - * Adapter for differences in ONNX/TensorFlow. + * Adapter for differences in different model types. */ public interface AttributeMap { Optional<Value> get(String key); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java index 8413ed74118..2d401469e40 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java @@ -1,8 +1,8 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.Reduce; import com.yahoo.tensor.functions.TensorFunction; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java index 95a77c07590..6ce9abf2ec9 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java @@ -3,8 +3,8 @@ package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java index e91c2305f7d..ff87412396d 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java @@ -2,8 +2,8 @@ package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode; @@ -24,8 +24,6 @@ import java.util.Iterator; import java.util.List; import java.util.stream.Collectors; -import static com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType.tensorSize; - public class Reshape extends IntermediateOperation { public Reshape(String modelName, String nodeName, List<IntermediateOperation> inputs) { @@ -52,7 +50,7 @@ public class Reshape extends IntermediateOperation { int size = cell.getValue().intValue(); if (size < 0) { size = -1 * (int)shape.reduce(Reduce.Aggregator.prod).asDouble() / - tensorSize(inputType.type()).intValue(); + OrderedTensorType.tensorSize(inputType.type()).intValue(); } outputTypeBuilder.add(TensorType.Dimension.indexed( String.format("%s_%d", vespaName(), dimensionIndex), size)); @@ -82,7 +80,7 @@ public class Reshape extends IntermediateOperation { } public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) { - if (!tensorSize(inputType).equals(tensorSize(outputType))) { + if (!OrderedTensorType.tensorSize(inputType).equals(OrderedTensorType.tensorSize(outputType))) { throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping"); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/package-info.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/package-info.java index 1530754cc43..bb55ed768a6 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/package-info.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/package-info.java @@ -1,8 +1,8 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. /** - * ONNX integration + * Model integration */ @ExportPackage -package com.yahoo.searchlib.rankingexpression.integration.ml; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer; import com.yahoo.osgi.annotation.ExportPackage; diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/DimensionRenamerTest.java index b3dafff621c..4bd28a74d6f 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/DimensionRenamerTest.java @@ -1,5 +1,5 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer; import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; import org.junit.Test; diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/OrderedTensorTypeTestCase.java index 55e1d234782..5118081637e 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/OrderedTensorTypeTestCase.java @@ -1,5 +1,5 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer; import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import org.junit.Test; diff --git a/standalone-container/vespa-standalone-container.spec b/standalone-container/vespa-standalone-container.spec index f051142c516..6143df6a446 100644 --- a/standalone-container/vespa-standalone-container.spec +++ b/standalone-container/vespa-standalone-container.spec @@ -54,6 +54,7 @@ declare -a modules=( jdisc_core jdisc_http_service model-evaluation + model-integration security-utils simplemetrics standalone-container |