summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-11-21 19:21:22 +0100
committerJon Bratseth <bratseth@oath.com>2018-11-21 19:21:22 +0100
commit99ca9b2907ff637fc6e4e0a61860923ac1c9dee5 (patch)
treed5a5e408d56e9165cd716e9531ab9bcec6a29e4a
parent61cae2609740b51c180b2f507b5e4d0eb399fedc (diff)
Separate model integration into a separate module
This allows us to access model importers (such as TensorFlow) in config models without loading one instance per config model instance, which is not possible with TensorFlow because it depends on JNI code.
-rw-r--r--CMakeLists.txt1
-rw-r--r--config-model/pom.xml14
-rw-r--r--config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java55
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java2
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java2
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java2
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java2
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java2
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java7
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java18
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java3
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java3
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java3
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java3
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java5
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java6
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java13
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java2
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java15
-rw-r--r--configserver/src/main/resources/configserver-app/services.xml1
-rw-r--r--container-dev/pom.xml24
-rw-r--r--container-disc/pom.xml1
-rw-r--r--container/pom.xml12
-rw-r--r--document/src/main/java/com/yahoo/document/Document.java1
-rw-r--r--model-integration/pom.xml100
-rw-r--r--model-integration/src/main/config/model-integration.xml10
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java)8
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java)5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java)6
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java)10
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/package-info.java5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/AttributeConverter.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java)9
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java)10
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TensorConverter.java)14
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java)7
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java)10
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/VariableConverter.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverter.java)8
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/package-info.java4
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/XGBoostImporter.java)5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostParser.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostParser.java)10
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostTree.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostTree.java)2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/package-info.java4
-rw-r--r--model-integration/src/main/protobuf/onnx.proto (renamed from searchlib/src/main/protobuf/onnx.proto)0
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java)10
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BatchNormImportTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java)6
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BlogEvaluationBenchmark.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BlogEvaluationBenchmark.java)10
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java)17
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/MnistImportTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java)11
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java)19
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java)4
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/VariableConverterTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverterTestCase.java)4
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java28
-rw-r--r--model-integration/src/test/models/onnx/mnist_softmax/mnist_softmax.onnx (renamed from searchlib/src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx)bin31758 -> 31758 bytes
-rw-r--r--model-integration/src/test/models/tensorflow/batch_norm/batch_normalization_mnist.py (renamed from searchlib/src/test/files/integration/tensorflow/batch_norm/batch_normalization_mnist.py)0
-rw-r--r--model-integration/src/test/models/tensorflow/batch_norm/saved/saved_model.pbtxt (renamed from searchlib/src/test/files/integration/tensorflow/batch_norm/saved/saved_model.pbtxt)0
-rw-r--r--model-integration/src/test/models/tensorflow/batch_norm/saved/variables/variables.data-00000-of-00001 (renamed from searchlib/src/test/files/integration/tensorflow/batch_norm/saved/variables/variables.data-00000-of-00001)bin1073000 -> 1073000 bytes
-rw-r--r--model-integration/src/test/models/tensorflow/batch_norm/saved/variables/variables.index (renamed from searchlib/src/test/files/integration/tensorflow/batch_norm/saved/variables/variables.index)bin686 -> 686 bytes
-rw-r--r--model-integration/src/test/models/tensorflow/blog/saved/saved_model.pbtxt (renamed from searchlib/src/test/files/integration/tensorflow/blog/saved/saved_model.pbtxt)0
-rw-r--r--model-integration/src/test/models/tensorflow/blog/saved/variables/variables.data-00000-of-00001 (renamed from searchlib/src/test/files/integration/tensorflow/blog/saved/variables/variables.data-00000-of-00001)bin1579020 -> 1579020 bytes
-rw-r--r--model-integration/src/test/models/tensorflow/blog/saved/variables/variables.index (renamed from searchlib/src/test/files/integration/tensorflow/blog/saved/variables/variables.index)bin520 -> 520 bytes
-rw-r--r--model-integration/src/test/models/tensorflow/dropout/dropout.py (renamed from searchlib/src/test/files/integration/tensorflow/dropout/dropout.py)0
-rw-r--r--model-integration/src/test/models/tensorflow/dropout/saved/saved_model.pbtxt (renamed from searchlib/src/test/files/integration/tensorflow/dropout/saved/saved_model.pbtxt)0
-rw-r--r--model-integration/src/test/models/tensorflow/dropout/saved/variables/variables.data-00000-of-00001 (renamed from searchlib/src/test/files/integration/tensorflow/dropout/saved/variables/variables.data-00000-of-00001)bin31400 -> 31400 bytes
-rw-r--r--model-integration/src/test/models/tensorflow/dropout/saved/variables/variables.index (renamed from searchlib/src/test/files/integration/tensorflow/dropout/saved/variables/variables.index)bin165 -> 165 bytes
-rw-r--r--model-integration/src/test/models/tensorflow/mnist/saved/saved_model.pbtxt (renamed from searchlib/src/test/files/integration/tensorflow/mnist/saved/saved_model.pbtxt)0
-rw-r--r--model-integration/src/test/models/tensorflow/mnist/saved/variables/variables.data-00000-of-00001 (renamed from searchlib/src/test/files/integration/tensorflow/mnist/saved/variables/variables.data-00000-of-00001)bin1066440 -> 1066440 bytes
-rw-r--r--model-integration/src/test/models/tensorflow/mnist/saved/variables/variables.index (renamed from searchlib/src/test/files/integration/tensorflow/mnist/saved/variables/variables.index)bin308 -> 308 bytes
-rw-r--r--model-integration/src/test/models/tensorflow/mnist/simple_mnist.py (renamed from searchlib/src/test/files/integration/tensorflow/mnist/simple_mnist.py)0
-rw-r--r--model-integration/src/test/models/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py (renamed from searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py)0
-rw-r--r--model-integration/src/test/models/tensorflow/mnist_softmax/saved/saved_model.pbtxt (renamed from searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt)0
-rw-r--r--model-integration/src/test/models/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001 (renamed from searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001)bin31400 -> 31400 bytes
-rw-r--r--model-integration/src/test/models/tensorflow/mnist_softmax/saved/variables/variables.index (renamed from searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.index)bin159 -> 159 bytes
-rw-r--r--model-integration/src/test/models/xgboost/xgboost.2.2.json19
-rw-r--r--pom.xml1
-rw-r--r--searchlib/pom.xml32
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/ImportedModel.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java)26
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/ImportedModels.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java)22
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/ModelImporter.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java)6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java8
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/package-info.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/package-info.java)4
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/DimensionRenamerTest.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java)2
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/OrderedTensorTypeTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java)2
-rw-r--r--standalone-container/vespa-standalone-container.spec1
97 files changed, 417 insertions, 267 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index e97f601fb74..26e9fa8a52a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -81,6 +81,7 @@ add_subdirectory(messagebus)
add_subdirectory(messagebus_test)
add_subdirectory(metrics)
add_subdirectory(model-evaluation)
+add_subdirectory(model-integration)
add_subdirectory(node-repository)
add_subdirectory(orchestrator)
add_subdirectory(persistence)
diff --git a/config-model/pom.xml b/config-model/pom.xml
index a65a6a836ed..b377760249a 100644
--- a/config-model/pom.xml
+++ b/config-model/pom.xml
@@ -293,16 +293,10 @@
<scope>test</scope>
</dependency>
<dependency>
- <groupId>com.google.protobuf</groupId>
- <artifactId>protobuf-java</artifactId>
- </dependency>
- <dependency>
- <groupId>org.tensorflow</groupId>
- <artifactId>proto</artifactId>
- </dependency>
- <dependency>
- <groupId>org.tensorflow</groupId>
- <artifactId>tensorflow</artifactId>
+ <groupId>com.yahoo.vespa</groupId>
+ <artifactId>model-integration</artifactId>
+ <version>${project.version}</version>
+ <scope>test</scope>
</dependency>
</dependencies>
diff --git a/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java b/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java
index 574c25a2f84..d784722f77a 100644
--- a/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java
+++ b/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java
@@ -21,8 +21,9 @@ import com.yahoo.config.provision.Zone;
import com.yahoo.io.reader.NamedReader;
import com.yahoo.searchdefinition.RankProfileRegistry;
import com.yahoo.searchdefinition.SearchBuilder;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import com.yahoo.searchdefinition.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ModelImporter;
import com.yahoo.vespa.config.ConfigDefinition;
import com.yahoo.vespa.config.ConfigDefinitionBuilder;
import com.yahoo.vespa.config.ConfigDefinitionKey;
@@ -39,6 +40,7 @@ import java.io.IOException;
import java.io.Reader;
import java.time.Instant;
import java.util.Collection;
+import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
@@ -80,11 +82,23 @@ public class DeployState implements ConfigDefinitionStore {
return new Builder().applicationPackage(applicationPackage).build();
}
- private DeployState(ApplicationPackage applicationPackage, SearchDocumentModel searchDocumentModel, RankProfileRegistry rankProfileRegistry,
- FileRegistry fileRegistry, DeployLogger deployLogger, Optional<HostProvisioner> hostProvisioner, DeployProperties properties,
- Optional<ApplicationPackage> permanentApplicationPackage, Optional<ConfigDefinitionRepo> configDefinitionRepo,
- java.util.Optional<Model> previousModel, Set<Rotation> rotations, Zone zone, QueryProfiles queryProfiles,
- SemanticRules semanticRules, Instant now, Version wantedNodeVespaVersion) {
+ private DeployState(ApplicationPackage applicationPackage,
+ SearchDocumentModel searchDocumentModel,
+ RankProfileRegistry rankProfileRegistry,
+ FileRegistry fileRegistry,
+ DeployLogger deployLogger,
+ Optional<HostProvisioner> hostProvisioner,
+ DeployProperties properties,
+ Optional<ApplicationPackage> permanentApplicationPackage,
+ Optional<ConfigDefinitionRepo> configDefinitionRepo,
+ java.util.Optional<Model> previousModel,
+ Set<Rotation> rotations,
+ Collection<ModelImporter> modelImporters,
+ Zone zone,
+ QueryProfiles queryProfiles,
+ SemanticRules semanticRules,
+ Instant now,
+ Version wantedNodeVespaVersion) {
this.logger = deployLogger;
this.fileRegistry = fileRegistry;
this.rankProfileRegistry = rankProfileRegistry;
@@ -100,7 +114,8 @@ public class DeployState implements ConfigDefinitionStore {
this.zone = zone;
this.queryProfiles = queryProfiles; // TODO: Remove this by seeing how pagetemplates are propagated
this.semanticRules = semanticRules; // TODO: Remove this by seeing how pagetemplates are propagated
- this.importedModels = new ImportedModels(applicationPackage.getFileReference(ApplicationPackage.MODELS_DIR));
+ this.importedModels = new ImportedModels(applicationPackage.getFileReference(ApplicationPackage.MODELS_DIR),
+ modelImporters);
this.validationOverrides = applicationPackage.getValidationOverrides().map(ValidationOverrides::fromXml).orElse(ValidationOverrides.empty);
this.wantedNodeVespaVersion = wantedNodeVespaVersion;
@@ -232,6 +247,7 @@ public class DeployState implements ConfigDefinitionStore {
private Optional<ConfigDefinitionRepo> configDefinitionRepo = Optional.empty();
private Optional<Model> previousModel = Optional.empty();
private Set<Rotation> rotations = new HashSet<>();
+ private Collection<ModelImporter> modelImporters = Collections.emptyList();
private Zone zone = Zone.defaultZone();
private Instant now = Instant.now();
private Version wantedNodeVespaVersion = Vtag.currentVersion;
@@ -281,6 +297,11 @@ public class DeployState implements ConfigDefinitionStore {
return this;
}
+ public Builder modelImporters(Collection<ModelImporter> modelImporters) {
+ this.modelImporters = modelImporters;
+ return this;
+ }
+
public Builder zone(Zone zone) {
this.zone = zone;
return this;
@@ -305,9 +326,23 @@ public class DeployState implements ConfigDefinitionStore {
QueryProfiles queryProfiles = new QueryProfilesBuilder().build(applicationPackage);
SemanticRules semanticRules = new SemanticRuleBuilder().build(applicationPackage);
SearchDocumentModel searchDocumentModel = createSearchDocumentModel(rankProfileRegistry, logger, queryProfiles, validationParameters);
- return new DeployState(applicationPackage, searchDocumentModel, rankProfileRegistry, fileRegistry, logger, hostProvisioner,
- properties, permanentApplicationPackage, configDefinitionRepo, previousModel, rotations,
- zone, queryProfiles, semanticRules, now, wantedNodeVespaVersion);
+ return new DeployState(applicationPackage,
+ searchDocumentModel,
+ rankProfileRegistry,
+ fileRegistry,
+ logger,
+ hostProvisioner,
+ properties,
+ permanentApplicationPackage,
+ configDefinitionRepo,
+ previousModel,
+ rotations,
+ modelImporters,
+ zone,
+ queryProfiles,
+ semanticRules,
+ now,
+ wantedNodeVespaVersion);
}
private SearchDocumentModel createSearchDocumentModel(RankProfileRegistry rankProfileRegistry,
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
index 78f61d7192d..a0cbc4271f3 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
@@ -16,7 +16,7 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.TypeContext;
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java
index 0c5c7733dda..eb221a80638 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java
@@ -12,7 +12,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.RankProfileRegistry;
import com.yahoo.searchdefinition.Search;
import com.yahoo.searchdefinition.derived.validation.Validation;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import java.io.IOException;
import java.io.Writer;
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
index fcbfb47c597..c955cd9956b 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
@@ -5,7 +5,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.RankProfileRegistry;
import com.yahoo.searchdefinition.RankingConstant;
import com.yahoo.searchdefinition.RankingConstants;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.Search;
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
index a1b0e72051b..fd468c8eca7 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
@@ -9,7 +9,7 @@ import com.yahoo.searchdefinition.document.RankType;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java
index 2fe2dacf2ce..f52eedc8d08 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java
@@ -4,7 +4,7 @@ package com.yahoo.searchdefinition.expressiontransforms;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import com.yahoo.searchlib.rankingexpression.transform.TransformContext;
import java.util.HashMap;
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
index fef198f7939..b4de5e925f6 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
@@ -33,8 +33,8 @@ import com.yahoo.searchdefinition.derived.RankProfileList;
import com.yahoo.searchdefinition.processing.Processing;
import com.yahoo.vespa.model.container.search.QueryProfiles;
import com.yahoo.vespa.model.ml.ConvertedModel;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import com.yahoo.vespa.config.ConfigDefinitionKey;
import com.yahoo.vespa.config.ConfigKey;
import com.yahoo.vespa.config.ConfigPayload;
@@ -150,7 +150,8 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
this(configModelRegistry, deployState, true, null);
}
- private VespaModel(ConfigModelRegistry configModelRegistry, DeployState deployState, boolean complete, FileDistributor fileDistributor) throws IOException, SAXException {
+ private VespaModel(ConfigModelRegistry configModelRegistry, DeployState deployState, boolean complete, FileDistributor fileDistributor)
+ throws IOException, SAXException {
super("vespamodel");
this.validationOverrides = deployState.validationOverrides();
configModelRegistry = new VespaConfigModelRegistry(configModelRegistry);
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java
index 2af9b297e9e..ff64b6753f9 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java
@@ -21,6 +21,7 @@ import com.yahoo.config.model.deploy.DeployProperties;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.config.provision.Version;
import com.yahoo.config.provision.Zone;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ModelImporter;
import com.yahoo.vespa.config.VespaVersion;
import com.yahoo.vespa.model.application.validation.Validation;
@@ -29,6 +30,8 @@ import org.xml.sax.SAXException;
import java.io.IOException;
import java.time.Clock;
import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
import java.util.List;
import java.util.logging.Logger;
@@ -41,19 +44,17 @@ public class VespaModelFactory implements ModelFactory {
private static final Logger log = Logger.getLogger(VespaModelFactory.class.getName());
private final ConfigModelRegistry configModelRegistry;
+ private final Collection<ModelImporter> modelImporters;
private final Zone zone;
private final Clock clock;
private final Version version;
/** Creates a factory for vespa models for this version of the source */
@Inject
- public VespaModelFactory(ComponentRegistry<ConfigModelPlugin> pluginRegistry, Zone zone) {
- this(Version.fromIntValues(VespaVersion.major, VespaVersion.minor, VespaVersion.micro), pluginRegistry, zone);
- }
-
- /** Creates a factory for vespa models of a particular version */
- public VespaModelFactory(Version version, ComponentRegistry<ConfigModelPlugin> pluginRegistry, Zone zone) {
- this.version = version;
+ public VespaModelFactory(ComponentRegistry<ConfigModelPlugin> pluginRegistry,
+ ComponentRegistry<ModelImporter> modelImporters,
+ Zone zone) {
+ this.version = Version.fromIntValues(VespaVersion.major, VespaVersion.minor, VespaVersion.micro);
List<ConfigModelBuilder> modelBuilders = new ArrayList<>();
for (ConfigModelPlugin plugin : pluginRegistry.allComponents()) {
if (plugin instanceof ConfigModelBuilder) {
@@ -61,6 +62,7 @@ public class VespaModelFactory implements ModelFactory {
}
}
this.configModelRegistry = new MapConfigModelRegistry(modelBuilders);
+ this.modelImporters = modelImporters.allComponents();
this.zone = zone;
this.clock = Clock.systemUTC();
}
@@ -79,6 +81,7 @@ public class VespaModelFactory implements ModelFactory {
} else {
this.configModelRegistry = configModelRegistry;
}
+ this.modelImporters = Collections.emptyList();
this.zone = Zone.defaultZone();
this.clock = clock;
}
@@ -137,6 +140,7 @@ public class VespaModelFactory implements ModelFactory {
.properties(createDeployProperties(modelContext.properties()))
.modelHostProvisioner(createHostProvisioner(modelContext))
.rotations(modelContext.properties().rotations())
+ .modelImporters(modelImporters)
.zone(zone)
.now(clock.instant())
.wantedNodeVespaVersion(modelContext.wantedNodeVespaVersion());
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
index 30586b1e677..18f3cd6e088 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
@@ -19,7 +19,7 @@ import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
@@ -41,7 +41,6 @@ import com.yahoo.tensor.serialization.TypedBinaryFormat;
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
-import java.io.Reader;
import java.io.StringReader;
import java.io.UncheckedIOException;
import java.util.ArrayList;
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java
index 07a36832094..ad1676ee7e1 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java
@@ -4,7 +4,7 @@ package com.yahoo.searchdefinition;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.derived.DerivedConfiguration;
import com.yahoo.searchdefinition.parser.ParseException;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import com.yahoo.yolean.Exceptions;
import org.junit.Test;
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java
index 4df3add13c5..14847938f68 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java
@@ -16,7 +16,7 @@ import com.yahoo.searchdefinition.document.RankType;
import com.yahoo.searchdefinition.document.SDDocumentType;
import com.yahoo.searchdefinition.document.SDField;
import com.yahoo.searchdefinition.parser.ParseException;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import org.junit.Test;
import java.util.Iterator;
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java
index 8df3985fd24..4f46d813ba5 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java
@@ -5,7 +5,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RawRankProfile;
import com.yahoo.searchdefinition.parser.ParseException;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java
index 150469cc928..551ef6968d2 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java
@@ -3,7 +3,7 @@ package com.yahoo.searchdefinition;
import com.yahoo.collections.Pair;
import com.yahoo.search.query.profile.QueryProfileRegistry;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import com.yahoo.yolean.Exceptions;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RawRankProfile;
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java
index e507a6c48e4..79adb1cf6bf 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java
@@ -6,8 +6,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RawRankProfile;
import com.yahoo.searchdefinition.parser.ParseException;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
-import com.yahoo.yolean.Exceptions;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import org.junit.Test;
import java.util.Optional;
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java
index 6a1e5b207c6..d97dfb0cbf9 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java
@@ -9,7 +9,7 @@ import com.yahoo.search.query.profile.types.QueryProfileType;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RawRankProfile;
import com.yahoo.searchdefinition.parser.ParseException;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import org.junit.Test;
import java.util.List;
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java
index 5e649c2e551..dedfbce6ae8 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java
@@ -4,9 +4,8 @@ package com.yahoo.searchdefinition;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.derived.DerivedConfiguration;
import com.yahoo.searchdefinition.parser.ParseException;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import com.yahoo.yolean.Exceptions;
-import org.junit.Ignore;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java
index f39896f5779..437f43976b7 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java
@@ -3,12 +3,11 @@ package com.yahoo.searchdefinition.derived;
import com.yahoo.document.DocumenttypesConfig;
import com.yahoo.document.config.DocumentmanagerConfig;
-import com.yahoo.io.IOUtils;
import com.yahoo.searchdefinition.Search;
import com.yahoo.searchdefinition.SearchBuilder;
import com.yahoo.searchdefinition.SearchDefinitionTestCase;
import com.yahoo.searchdefinition.parser.ParseException;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import com.yahoo.vespa.configmodel.producers.DocumentManager;
import com.yahoo.vespa.configmodel.producers.DocumentTypes;
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java
index f4344c9b03c..409b4236a9c 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java
@@ -9,12 +9,9 @@ import com.yahoo.searchdefinition.SearchBuilder;
import com.yahoo.searchdefinition.SearchDefinitionTestCase;
import com.yahoo.searchdefinition.document.SDDocumentType;
import com.yahoo.searchdefinition.document.SDField;
-import com.yahoo.searchdefinition.parser.ParseException;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import org.junit.Test;
-import java.io.IOException;
-
/**
* Tests deriving rank for files from search definitions
*
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java
index 2bae285301c..f2dd16577d1 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java
@@ -11,7 +11,7 @@ import com.yahoo.searchdefinition.SearchBuilder;
import com.yahoo.searchdefinition.document.SDDocumentType;
import com.yahoo.searchdefinition.document.SDField;
import com.yahoo.searchdefinition.processing.Processing;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import com.yahoo.vespa.model.container.search.QueryProfiles;
import org.junit.Test;
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java
index 723cd58a34a..93b366c8d08 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java
@@ -5,7 +5,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.Search;
import com.yahoo.searchdefinition.SearchBuilder;
import com.yahoo.searchdefinition.parser.ParseException;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import org.junit.Test;
import java.io.File;
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java
index 8941b07101d..3a44c123f05 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java
@@ -10,7 +10,7 @@ import com.yahoo.searchdefinition.SearchDefinitionTestCase;
import com.yahoo.searchdefinition.document.SDDocumentType;
import com.yahoo.searchdefinition.document.SDField;
import com.yahoo.searchdefinition.processing.Processing;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import com.yahoo.vespa.model.container.search.QueryProfiles;
import org.junit.Test;
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java
index d38bce04617..1c368ff6f10 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java
@@ -6,13 +6,11 @@ import com.yahoo.searchdefinition.Search;
import com.yahoo.searchdefinition.SearchBuilder;
import com.yahoo.searchdefinition.SearchDefinitionTestCase;
import com.yahoo.searchdefinition.derived.DerivedConfiguration;
-import com.yahoo.searchdefinition.derived.Deriver;
import com.yahoo.searchdefinition.document.SDDocumentType;
import com.yahoo.searchdefinition.parser.ParseException;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
-import org.junit.Ignore;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import org.junit.Test;
-import java.io.File;
+
import java.io.IOException;
import static org.junit.Assert.assertEquals;
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
index cff9abb08ed..7f590d4b230 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition.processing;
+import com.google.common.collect.ImmutableList;
import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.config.model.test.MockApplicationPackage;
import com.yahoo.path.Path;
@@ -10,7 +11,11 @@ import com.yahoo.searchdefinition.RankProfileRegistry;
import com.yahoo.searchdefinition.Search;
import com.yahoo.searchdefinition.SearchBuilder;
import com.yahoo.searchdefinition.parser.ParseException;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ModelImporter;
+import ai.vespa.rankingexpression.importer.onnx.OnnxImporter;
+import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter;
+import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter;
import java.util.HashMap;
import java.util.List;
@@ -26,6 +31,9 @@ import static org.junit.Assert.assertEquals;
*/
class RankProfileSearchFixture {
+ private final ImmutableList<ModelImporter> importers = ImmutableList.of(new TensorFlowImporter(),
+ new OnnxImporter(),
+ new XGBoostImporter());
private RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
private final QueryProfileRegistry queryProfileRegistry;
private Search search;
@@ -83,7 +91,8 @@ class RankProfileSearchFixture {
public RankProfile compileRankProfile(String rankProfile, Path applicationDir) {
RankProfile compiled = rankProfileRegistry.get(search, rankProfile)
- .compile(queryProfileRegistry, new ImportedModels(applicationDir.toFile()));
+ .compile(queryProfileRegistry,
+ new ImportedModels(applicationDir.toFile(), importers));
compiledRankProfiles.put(rankProfile, compiled);
return compiled;
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java
index fd048737b43..78dcae53f7e 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java
@@ -8,7 +8,7 @@ import com.yahoo.searchdefinition.derived.DerivedConfiguration;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RawRankProfile;
import com.yahoo.searchdefinition.parser.ParseException;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import org.junit.Test;
import java.io.IOException;
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
index 6e3a227e2a9..7eee18c1059 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
@@ -17,7 +17,7 @@ import com.yahoo.searchdefinition.SearchDefinitionTestCase;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RawRankProfile;
import com.yahoo.searchdefinition.parser.ParseException;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import org.junit.Test;
import java.util.List;
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java
index 2ae629562d0..4d8bdb7a6c6 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java
@@ -1,11 +1,17 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model.ml;
+import com.google.common.collect.ImmutableList;
import com.yahoo.config.model.ApplicationPackageTester;
+import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.searchdefinition.RankingConstant;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ModelImporter;
+import ai.vespa.rankingexpression.importer.onnx.OnnxImporter;
+import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter;
+import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.serialization.TypedBinaryFormat;
import com.yahoo.vespa.model.VespaModel;
@@ -26,6 +32,10 @@ import static org.junit.Assert.assertEquals;
*/
public class ImportedModelTester {
+ private final ImmutableList<ModelImporter> importers = ImmutableList.of(new TensorFlowImporter(),
+ new OnnxImporter(),
+ new XGBoostImporter());
+
private final String modelName;
private final Path applicationDir;
@@ -36,7 +46,10 @@ public class ImportedModelTester {
public VespaModel createVespaModel() {
try {
- return new VespaModel(ApplicationPackageTester.create(applicationDir.toString()).app());
+ DeployState.Builder state = new DeployState.Builder();
+ state.applicationPackage(ApplicationPackageTester.create(applicationDir.toString()).app());
+ state.modelImporters(importers);
+ return new VespaModel(state.build());
}
catch (SAXException | IOException e) {
throw new RuntimeException(e);
diff --git a/configserver/src/main/resources/configserver-app/services.xml b/configserver/src/main/resources/configserver-app/services.xml
index 5f60be8c202..3dd6e0090c5 100644
--- a/configserver/src/main/resources/configserver-app/services.xml
+++ b/configserver/src/main/resources/configserver-app/services.xml
@@ -55,6 +55,7 @@
<preprocess:include file='config-models.xml' required='false' />
<preprocess:include file='node-repository.xml' required='false' />
<preprocess:include file='hosted-vespa/routing-status.xml' required='false' />
+ <preprocess:include file='model-integration.xml' required='true' />
<!-- TODO Vespa 7: Remove scoreboard.xml, replaced by metrics-packets.xml -->
<preprocess:include file='hosted-vespa/scoreboard.xml' required='false' />
diff --git a/container-dev/pom.xml b/container-dev/pom.xml
index 1b0f36e3adb..7d61882c085 100644
--- a/container-dev/pom.xml
+++ b/container-dev/pom.xml
@@ -104,18 +104,6 @@
<groupId>org.apache.httpcomponents</groupId>
<artifactId>httpclient</artifactId>
</exclusion>
- <exclusion>
- <groupId>org.tensorflow</groupId>
- <artifactId>proto</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.tensorflow</groupId>
- <artifactId>tensorflow</artifactId>
- </exclusion>
- <exclusion>
- <groupId>com.google.protobuf</groupId>
- <artifactId>protobuf-java</artifactId>
- </exclusion>
</exclusions>
</dependency>
<dependency>
@@ -178,18 +166,6 @@
<groupId>xerces</groupId>
<artifactId>xercesImpl</artifactId>
</exclusion>
- <exclusion>
- <groupId>org.tensorflow</groupId>
- <artifactId>proto</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.tensorflow</groupId>
- <artifactId>tensorflow</artifactId>
- </exclusion>
- <exclusion>
- <groupId>com.google.protobuf</groupId>
- <artifactId>protobuf-java</artifactId>
- </exclusion>
</exclusions>
</dependency>
<dependency>
diff --git a/container-disc/pom.xml b/container-disc/pom.xml
index 62985192ad3..be4ac23a938 100644
--- a/container-disc/pom.xml
+++ b/container-disc/pom.xml
@@ -174,6 +174,7 @@
jdisc-security-filters-jar-with-dependencies.jar,
jdisc_http_service-jar-with-dependencies.jar,
model-evaluation-jar-with-dependencies.jar,
+ model-integration-jar-with-dependencies.jar,
vespaclient-container-plugin-jar-with-dependencies.jar,
vespa-athenz-jar-with-dependencies.jar,
security-utils-jar-with-dependencies.jar,
diff --git a/container/pom.xml b/container/pom.xml
index 32a7947d6d5..529f01e0a40 100644
--- a/container/pom.xml
+++ b/container/pom.xml
@@ -53,18 +53,6 @@
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
</exclusion>
- <exclusion>
- <groupId>org.tensorflow</groupId>
- <artifactId>proto</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.tensorflow</groupId>
- <artifactId>tensorflow</artifactId>
- </exclusion>
- <exclusion>
- <groupId>com.google.protobuf</groupId>
- <artifactId>protobuf-java</artifactId>
- </exclusion>
</exclusions>
</dependency>
</dependencies>
diff --git a/document/src/main/java/com/yahoo/document/Document.java b/document/src/main/java/com/yahoo/document/Document.java
index 85049baf8b0..222ebe29c6d 100644
--- a/document/src/main/java/com/yahoo/document/Document.java
+++ b/document/src/main/java/com/yahoo/document/Document.java
@@ -286,6 +286,7 @@ public class Document extends StructuredFieldValue {
/**
* Get JSON representation of the document root and its children contained in a JSON object
+ *
* @return JSON representation of document
*/
public String toJson() {
diff --git a/model-integration/pom.xml b/model-integration/pom.xml
new file mode 100644
index 00000000000..28a00dcbdbc
--- /dev/null
+++ b/model-integration/pom.xml
@@ -0,0 +1,100 @@
+<?xml version="1.0"?>
+<!-- Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -->
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0
+ http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+ <parent>
+ <groupId>com.yahoo.vespa</groupId>
+ <artifactId>parent</artifactId>
+ <version>6-SNAPSHOT</version>
+ <relativePath>../parent/pom.xml</relativePath>
+ </parent>
+ <artifactId>model-integration</artifactId>
+ <version>6-SNAPSHOT</version>
+ <packaging>container-plugin</packaging>
+ <dependencies>
+ <dependency>
+ <groupId>junit</groupId>
+ <artifactId>junit</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>com.yahoo.vespa</groupId>
+ <artifactId>component</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>com.yahoo.vespa</groupId>
+ <artifactId>vespajlib</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>com.yahoo.vespa</groupId>
+ <artifactId>searchlib</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>com.google.guava</groupId>
+ <artifactId>guava</artifactId>
+ <scope>provided</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>com.google.protobuf</groupId>
+ <artifactId>protobuf-java</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.tensorflow</groupId>
+ <artifactId>proto</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.tensorflow</groupId>
+ <artifactId>tensorflow</artifactId>
+ </dependency>
+ </dependencies>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>com.yahoo.vespa</groupId>
+ <artifactId>bundle-plugin</artifactId>
+ <extensions>true</extensions>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-compiler-plugin</artifactId>
+ <configuration>
+ <compilerArgs>
+ <arg>-Xlint:rawtypes</arg>
+ <arg>-Xlint:unchecked</arg>
+ <arg>-Werror</arg>
+ </compilerArgs>
+ </configuration>
+ </plugin>
+ <plugin>
+ <groupId>com.github.os72</groupId>
+ <artifactId>protoc-jar-maven-plugin</artifactId>
+ <version>3.5.1.1</version>
+ <executions>
+ <execution>
+ <phase>generate-sources</phase>
+ <goals>
+ <goal>run</goal>
+ </goals>
+ <configuration>
+ <addSources>main</addSources>
+ <outputDirectory>${project.build.directory}/generated-sources/protobuf</outputDirectory>
+ <inputDirectories>
+ <include>src/main/protobuf</include>
+ </inputDirectories>
+ </configuration>
+ </execution>
+ </executions>
+ </plugin>
+ </plugins>
+ </build>
+</project>
diff --git a/model-integration/src/main/config/model-integration.xml b/model-integration/src/main/config/model-integration.xml
new file mode 100644
index 00000000000..da45ce23575
--- /dev/null
+++ b/model-integration/src/main/config/model-integration.xml
@@ -0,0 +1,10 @@
+<!-- Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -->
+<!-- Component which can import some ml model.
+ This is included into the config server services.xml to enable it to translate
+ model pseudofeatures in ranking expressions during config mddel building.
+ It is provided as separate bundles instead of being included in the config model
+ because some of these (TensorFlow) includes
+ JNI code, and so can only exist in one instance in the server. -->
+<component id="ai.vespa.rankingexpression.importer.onnx.OnnxImporter" bundle="model-integration" />
+<component id="ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter" bundle="model-integration" />
+<component id="ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter" bundle="model-integration" />
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
index 3fe92440cae..81d0753ea4b 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
@@ -1,6 +1,6 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx;
+package ai.vespa.rankingexpression.importer.onnx;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
@@ -28,9 +28,9 @@ import java.util.stream.Collectors;
*
* @author lesters
*/
-public class GraphImporter {
+class GraphImporter {
- public static IntermediateOperation mapOperation(Onnx.NodeProto node,
+ private static IntermediateOperation mapOperation(Onnx.NodeProto node,
List<IntermediateOperation> inputs,
IntermediateGraph graph) {
String nodeName = node.getName();
@@ -79,7 +79,7 @@ public class GraphImporter {
return op;
}
- public static IntermediateGraph importGraph(String modelName, Onnx.ModelProto model) {
+ static IntermediateGraph importGraph(String modelName, Onnx.ModelProto model) {
Onnx.GraphProto onnxGraph = model.getGraph();
IntermediateGraph intermediateGraph = new IntermediateGraph(modelName);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java
index e6bb5f40b3f..0418581d7b2 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java
@@ -1,9 +1,10 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package ai.vespa.rankingexpression.importer.onnx;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx.GraphImporter;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ModelImporter;
import onnx.Onnx;
import java.io.File;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java
index 18856d4a25f..6dd33c79852 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java
@@ -1,6 +1,6 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx;
+package ai.vespa.rankingexpression.importer.onnx;
import com.google.protobuf.ByteString;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
@@ -17,9 +17,9 @@ import java.nio.FloatBuffer;
*
* @author lesters
*/
-public class TensorConverter {
+class TensorConverter {
- public static Tensor toVespaTensor(Onnx.TensorProto tensorProto, OrderedTensorType type) {
+ static Tensor toVespaTensor(Onnx.TensorProto tensorProto, OrderedTensorType type) {
Values values = readValuesOf(tensorProto);
IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type.type());
for (int i = 0; i < values.size(); i++) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
index 715c55d8323..43ceaa747b7 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
@@ -1,6 +1,6 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx;
+package ai.vespa.rankingexpression.importer.onnx;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
@@ -11,9 +11,9 @@ import onnx.Onnx;
*
* @author lesters
*/
-public class TypeConverter {
+class TypeConverter {
- public static void verifyType(Onnx.TypeProto typeProto, OrderedTensorType type) {
+ static void verifyType(Onnx.TypeProto typeProto, OrderedTensorType type) {
Onnx.TensorShapeProto shape = typeProto.getTensorType().getShape();
if (shape != null) {
if (shape.getDimCount() != type.rank()) {
@@ -30,11 +30,11 @@ public class TypeConverter {
}
}
- public static OrderedTensorType fromOnnxType(Onnx.TypeProto type) {
+ static OrderedTensorType fromOnnxType(Onnx.TypeProto type) {
return fromOnnxType(type, "d"); // standard naming convention: d0, d1, ...
}
- public static OrderedTensorType fromOnnxType(Onnx.TypeProto type, String dimensionPrefix) {
+ private static OrderedTensorType fromOnnxType(Onnx.TypeProto type, String dimensionPrefix) {
Onnx.TensorShapeProto shape = type.getTensorType().getShape();
OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
for (int i = 0; i < shape.getDimCount(); ++ i) {
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/package-info.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/package-info.java
new file mode 100644
index 00000000000..9599cf8627c
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/package-info.java
@@ -0,0 +1,5 @@
+@ExportPackage
+
+package ai.vespa.rankingexpression.importer.onnx;
+
+import com.yahoo.osgi.annotation.ExportPackage; \ No newline at end of file
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/AttributeConverter.java
index 89b75e8e3e2..73310c78cab 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/AttributeConverter.java
@@ -1,5 +1,5 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow;
+package ai.vespa.rankingexpression.importer.tensorflow;
import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
@@ -20,15 +20,15 @@ import java.util.stream.Collectors;
*
* @author lesters
*/
-public class AttributeConverter implements IntermediateOperation.AttributeMap {
+class AttributeConverter implements IntermediateOperation.AttributeMap {
private final Map<String, AttrValue> attributeMap;
- public AttributeConverter(NodeDef node) {
+ private AttributeConverter(NodeDef node) {
attributeMap = node.getAttrMap();
}
- public static AttributeConverter convert(NodeDef node) {
+ static AttributeConverter convert(NodeDef node) {
return new AttributeConverter(node);
}
@@ -83,4 +83,5 @@ public class AttributeConverter implements IntermediateOperation.AttributeMap {
}
return Optional.empty();
}
+
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java
index e1b292f9e61..c012bc3c54f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java
@@ -1,6 +1,6 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow;
+package ai.vespa.rankingexpression.importer.tensorflow;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
@@ -43,9 +43,9 @@ import java.util.stream.Collectors;
*
* @author lesters
*/
-public class GraphImporter {
+class GraphImporter {
- public static IntermediateOperation mapOperation(NodeDef node,
+ private static IntermediateOperation mapOperation(NodeDef node,
List<IntermediateOperation> inputs,
IntermediateGraph graph) {
String nodeName = node.getName();
@@ -112,7 +112,7 @@ public class GraphImporter {
return op;
}
- public static IntermediateGraph importGraph(String modelName, SavedModelBundle bundle) throws IOException {
+ static IntermediateGraph importGraph(String modelName, SavedModelBundle bundle) throws IOException {
MetaGraphDef tfGraph = MetaGraphDef.parseFrom(bundle.metaGraphDef());
IntermediateGraph intermediateGraph = new IntermediateGraph(modelName);
@@ -209,7 +209,7 @@ public class GraphImporter {
throw new IllegalArgumentException("Could not find node '" + name + "'");
}
- public static org.tensorflow.Tensor<?> readVariable(String name, SavedModelBundle bundle) {
+ static org.tensorflow.Tensor<?> readVariable(String name, SavedModelBundle bundle) {
Session.Runner fetched = bundle.session().runner().fetch(name);
List<org.tensorflow.Tensor<?>> importedTensors = fetched.run();
if (importedTensors.size() != 1)
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java
index d2d0acfc964..4e67286ef09 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TensorConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java
@@ -1,5 +1,5 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow;
+package ai.vespa.rankingexpression.importer.tensorflow;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.IndexedTensor;
@@ -26,7 +26,7 @@ public class TensorConverter {
return toVespaTensor(tfTensor, "d");
}
- public static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) {
+ private static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) {
TensorType type = toVespaTensorType(tfTensor.shape(), dimensionPrefix);
Values values = readValuesOf(tfTensor);
IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type);
@@ -35,7 +35,7 @@ public class TensorConverter {
return builder.build();
}
- public static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, OrderedTensorType type) {
+ static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, OrderedTensorType type) {
Values values = readValuesOf(tfTensor);
IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type.type());
for (int i = 0; i < values.size(); i++) {
@@ -44,7 +44,7 @@ public class TensorConverter {
return builder.build();
}
- public static Tensor toVespaTensor(TensorProto tensorProto, TensorType type) {
+ static Tensor toVespaTensor(TensorProto tensorProto, TensorType type) {
IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type);
Values values = readValuesOf(tensorProto);
for (int i = 0; i < values.size(); ++i) {
@@ -71,7 +71,7 @@ public class TensorConverter {
return size;
}
- public static Long dimensionSize(TensorType.Dimension dim) {
+ private static Long dimensionSize(TensorType.Dimension dim) {
return dim.size().orElseThrow(() -> new IllegalArgumentException("Dimension has no size"));
}
@@ -182,8 +182,8 @@ public class TensorConverter {
}
private static abstract class ProtoValues extends Values {
- protected final TensorProto tensorProto;
- protected ProtoValues(TensorProto tensorProto) { this.tensorProto = tensorProto; }
+ final TensorProto tensorProto;
+ ProtoValues(TensorProto tensorProto) { this.tensorProto = tensorProto; }
}
private static class ProtoBoolValues extends ProtoValues {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java
index 7c18e04bae7..f8a25e4d94c 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java
@@ -1,8 +1,9 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package ai.vespa.rankingexpression.importer.tensorflow;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.GraphImporter;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ModelImporter;
import org.tensorflow.SavedModelBundle;
import java.io.File;
@@ -47,7 +48,7 @@ public class TensorFlowImporter extends ModelImporter {
}
/** Imports a TensorFlow model */
- ImportedModel importModel(String modelName, String modelDir, SavedModelBundle model) {
+ public ImportedModel importModel(String modelName, String modelDir, SavedModelBundle model) {
try {
IntermediateGraph graph = GraphImporter.importGraph(modelName, model);
return convertIntermediateGraphToModel(graph, modelDir);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java
index 67ad1edc312..a5a506fdb6d 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java
@@ -1,6 +1,6 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow;
+package ai.vespa.rankingexpression.importer.tensorflow;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
@@ -15,9 +15,9 @@ import java.util.List;
*
* @author lesters
*/
-public class TypeConverter {
+class TypeConverter {
- public static void verifyType(NodeDef node, OrderedTensorType type) {
+ static void verifyType(NodeDef node, OrderedTensorType type) {
TensorShapeProto shape = tensorFlowShape(node);
if (shape != null) {
if (shape.getDimCount() != type.rank()) {
@@ -50,11 +50,11 @@ public class TypeConverter {
return shapeList.get(0); // support multiple outputs?
}
- public static OrderedTensorType fromTensorFlowType(NodeDef node) {
+ static OrderedTensorType fromTensorFlowType(NodeDef node) {
return fromTensorFlowType(node, "d"); // standard naming convention: d0, d1, ...
}
- public static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) {
+ private static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) {
OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
TensorShapeProto shape = tensorFlowShape(node);
for (int i = 0; i < shape.getDimCount(); ++ i) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/VariableConverter.java
index 25bac27f315..b777ee07e58 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/VariableConverter.java
@@ -1,9 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package ai.vespa.rankingexpression.importer.tensorflow;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.TensorConverter;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.GraphImporter;
import com.yahoo.tensor.serialization.JsonFormat;
import com.yahoo.yolean.Exceptions;
import org.tensorflow.SavedModelBundle;
@@ -16,7 +14,7 @@ import java.nio.charset.StandardCharsets;
*
* @author bratseth
*/
-public class VariableConverter {
+class VariableConverter {
/**
* Reads the tensor with the given TensorFlow name at the given model location,
@@ -24,7 +22,7 @@ public class VariableConverter {
* Note that order of dimensions in the tensor type does matter as the TensorFlow tensor
* tensor dimensions are implicitly ordered.
*/
- public static byte[] importVariable(String modelDir, String tensorFlowVariableName, String orderedTypeSpec) {
+ static byte[] importVariable(String modelDir, String tensorFlowVariableName, String orderedTypeSpec) {
try (SavedModelBundle bundle = SavedModelBundle.load(modelDir, "serve")) {
return JsonFormat.encode(TensorConverter.toVespaTensor(GraphImporter.readVariable(tensorFlowVariableName,
bundle),
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/package-info.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/package-info.java
new file mode 100644
index 00000000000..0840e584d25
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/package-info.java
@@ -0,0 +1,4 @@
+@ExportPackage
+package ai.vespa.rankingexpression.importer.tensorflow;
+
+import com.yahoo.osgi.annotation.ExportPackage; \ No newline at end of file
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/XGBoostImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java
index 725f319a839..e87dd265d50 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/XGBoostImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java
@@ -1,8 +1,9 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package ai.vespa.rankingexpression.importer.xgboost;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.xgboost.XGBoostParser;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ModelImporter;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import java.io.File;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostParser.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostParser.java
index fef8bfec81d..2b215f816f5 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostParser.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostParser.java
@@ -1,5 +1,5 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.xgboost;
+package ai.vespa.rankingexpression.importer.xgboost;
import java.io.File;
import java.io.IOException;
@@ -13,7 +13,7 @@ import com.fasterxml.jackson.databind.ObjectMapper;
/**
* @author grace-lam
*/
-public class XGBoostParser {
+class XGBoostParser {
private List<XGBoostTree> xgboostTrees;
@@ -24,7 +24,7 @@ public class XGBoostParser {
* @throws JsonProcessingException Fails JSON parsing.
* @throws IOException Fails file reading.
*/
- public XGBoostParser(String filePath) throws JsonProcessingException, IOException {
+ XGBoostParser(String filePath) throws JsonProcessingException, IOException {
this.xgboostTrees = new ArrayList<>();
ObjectMapper mapper = new ObjectMapper();
JsonNode forestNode = mapper.readTree(new File(filePath));
@@ -38,7 +38,7 @@ public class XGBoostParser {
*
* @return Vespa ranking expressions.
*/
- public String toRankingExpression() {
+ String toRankingExpression() {
StringBuilder ret = new StringBuilder();
for (int i = 0; i < xgboostTrees.size(); i++) {
ret.append(treeToRankExp(xgboostTrees.get(i)));
@@ -55,7 +55,7 @@ public class XGBoostParser {
* @param node XGBoost tree node to convert.
* @return Vespa ranking expression for input node.
*/
- public String treeToRankExp(XGBoostTree node) {
+ private String treeToRankExp(XGBoostTree node) {
if (node.isLeaf()) {
return Double.toString(node.getLeaf());
} else {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostTree.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostTree.java
index 6bbc9abe8ae..e32e0f1eab5 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostTree.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostTree.java
@@ -1,5 +1,5 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.xgboost;
+package ai.vespa.rankingexpression.importer.xgboost;
import java.util.List;
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/package-info.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/package-info.java
new file mode 100644
index 00000000000..d310de9041b
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/package-info.java
@@ -0,0 +1,4 @@
+@ExportPackage
+package ai.vespa.rankingexpression.importer.xgboost;
+
+import com.yahoo.osgi.annotation.ExportPackage; \ No newline at end of file
diff --git a/searchlib/src/main/protobuf/onnx.proto b/model-integration/src/main/protobuf/onnx.proto
index dc6542867e0..dc6542867e0 100644
--- a/searchlib/src/main/protobuf/onnx.proto
+++ b/model-integration/src/main/protobuf/onnx.proto
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
index 6d0ee0a906d..a71d7a42551 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
@@ -1,11 +1,13 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package ai.vespa.rankingexpression.importer.onnx;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel;
+import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.junit.Test;
@@ -21,7 +23,7 @@ public class OnnxMnistSoftmaxImportTestCase {
@Test
public void testMnistSoftmaxImport() {
- ImportedModel model = new OnnxImporter().importModel("test", "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx");
+ ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/mnist_softmax/mnist_softmax.onnx");
// Check constants
assertEquals(2, model.largeConstants().size());
@@ -55,8 +57,8 @@ public class OnnxMnistSoftmaxImportTestCase {
@Test
public void testComparisonBetweenOnnxAndTensorflow() {
- String tfModelPath = "src/test/files/integration/tensorflow/mnist_softmax/saved";
- String onnxModelPath = "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx";
+ String tfModelPath = "src/test/models/tensorflow/mnist_softmax/saved";
+ String onnxModelPath = "src/test/models/onnx/mnist_softmax/mnist_softmax.onnx";
Tensor argument = placeholderArgument();
Tensor tensorFlowResult = evaluateTensorFlowModel(tfModelPath, argument, "Placeholder", "add");
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BatchNormImportTestCase.java
index e325c3d11b4..acd649d985b 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BatchNormImportTestCase.java
@@ -1,8 +1,8 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package ai.vespa.rankingexpression.importer.tensorflow;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
@@ -16,7 +16,7 @@ public class BatchNormImportTestCase {
@Test
public void testBatchNormImport() {
TestableTensorFlowModel model = new TestableTensorFlowModel("test",
- "src/test/files/integration/tensorflow/batch_norm/saved");
+ "src/test/models/tensorflow/batch_norm/saved");
ImportedModel.Signature signature = model.get().signature("serving_default");
assertEquals("Has skipped outputs",
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BlogEvaluationBenchmark.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BlogEvaluationBenchmark.java
index 5b0cca6a940..a878b284b2c 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BlogEvaluationBenchmark.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BlogEvaluationBenchmark.java
@@ -1,5 +1,5 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package ai.vespa.rankingexpression.importer.tensorflow;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
@@ -7,7 +7,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
import com.yahoo.searchlib.rankingexpression.evaluation.OptimizationReport;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -17,8 +17,6 @@ import org.tensorflow.Session;
import java.nio.FloatBuffer;
import java.util.List;
-import static com.yahoo.searchlib.rankingexpression.integration.ml.TestableTensorFlowModel.contextFrom;
-
/**
* Microbenchmark of imported ML model evaluation.
*
@@ -26,13 +24,13 @@ import static com.yahoo.searchlib.rankingexpression.integration.ml.TestableTenso
*/
public class BlogEvaluationBenchmark {
- static final String modelDir = "src/test/files/integration/tensorflow/blog/saved";
+ static final String modelDir = "src/test/models/tensorflow/blog/saved";
public static void main(String[] args) throws ParseException {
SavedModelBundle tensorFlowModel = SavedModelBundle.load(modelDir, "serve");
ImportedModel model = new TensorFlowImporter().importModel("blog", modelDir, tensorFlowModel);
- Context context = contextFrom(model);
+ Context context = TestableTensorFlowModel.contextFrom(model);
Tensor u = generateInputTensor();
Tensor d = generateInputTensor();
context.put("input_u", new TensorValue(u));
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java
index 8ca5a9a7888..6e58761e5ce 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java
@@ -1,9 +1,10 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package ai.vespa.rankingexpression.importer.tensorflow;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel;
import com.yahoo.tensor.TensorType;
+import org.junit.Assert;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
@@ -17,18 +18,18 @@ public class DropoutImportTestCase {
@Test
public void testDropoutImport() {
- TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/dropout/saved");
+ TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/models/tensorflow/dropout/saved");
// Check required functions
- assertEquals(1, model.get().inputs().size());
+ Assert.assertEquals(1, model.get().inputs().size());
assertTrue(model.get().inputs().containsKey("X"));
- assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(),
- model.get().inputs().get("X"));
+ Assert.assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(),
+ model.get().inputs().get("X"));
ImportedModel.Signature signature = model.get().signature("serving_default");
- assertEquals("Has skipped outputs",
- 0, model.get().signature("serving_default").skippedOutputs().size());
+ Assert.assertEquals("Has skipped outputs",
+ 0, model.get().signature("serving_default").skippedOutputs().size());
ExpressionFunction function = signature.outputExpression("y");
assertNotNull(function);
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/MnistImportTestCase.java
index 3d8d5d5a570..b338f46fb4d 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/MnistImportTestCase.java
@@ -1,8 +1,9 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package ai.vespa.rankingexpression.importer.tensorflow;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel;
+import org.junit.Assert;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
@@ -15,11 +16,11 @@ public class MnistImportTestCase {
@Test
public void testMnistImport() {
- TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/mnist/saved");
+ TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/models/tensorflow/mnist/saved");
ImportedModel.Signature signature = model.get().signature("serving_default");
- assertEquals("Has skipped outputs",
- 0, model.get().signature("serving_default").skippedOutputs().size());
+ Assert.assertEquals("Has skipped outputs",
+ 0, model.get().signature("serving_default").skippedOutputs().size());
ExpressionFunction output = signature.outputExpression("y");
assertNotNull(output);
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java
index feba40601e3..7e8cbef8ada 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java
@@ -1,10 +1,11 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package ai.vespa.rankingexpression.importer.tensorflow;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import org.junit.Assert;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
@@ -18,10 +19,10 @@ public class TensorFlowMnistSoftmaxImportTestCase {
@Test
public void testMnistSoftmaxImport() {
- TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/mnist_softmax/saved");
+ TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/models/tensorflow/mnist_softmax/saved");
// Check constants
- assertEquals(2, model.get().largeConstants().size());
+ Assert.assertEquals(2, model.get().largeConstants().size());
Tensor constant0 = model.get().largeConstants().get("test_Variable_read");
assertNotNull(constant0);
@@ -36,16 +37,16 @@ public class TensorFlowMnistSoftmaxImportTestCase {
assertEquals(10, constant1.size());
// Check (provided) functions
- assertEquals(0, model.get().functions().size());
+ Assert.assertEquals(0, model.get().functions().size());
// Check required functions
- assertEquals(1, model.get().inputs().size());
+ Assert.assertEquals(1, model.get().inputs().size());
assertTrue(model.get().inputs().containsKey("Placeholder"));
- assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(),
- model.get().inputs().get("Placeholder"));
+ Assert.assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(),
+ model.get().inputs().get("Placeholder"));
// Check signatures
- assertEquals(1, model.get().signatures().size());
+ Assert.assertEquals(1, model.get().signatures().size());
ImportedModel.Signature signature = model.get().signatures().get("serving_default");
assertNotNull(signature);
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java
index fbe7c5fac63..dbed537885e 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java
@@ -1,5 +1,5 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package ai.vespa.rankingexpression.importer.tensorflow;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
@@ -7,7 +7,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverterTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/VariableConverterTestCase.java
index aabbdf33c9d..c9fffe143b4 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverterTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/VariableConverterTestCase.java
@@ -1,5 +1,5 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package ai.vespa.rankingexpression.importer.tensorflow;
import org.junit.Test;
@@ -11,7 +11,7 @@ public class VariableConverterTestCase {
@Test
public void testConversion() {
- byte[] converted = VariableConverter.importVariable("src/test/files/integration/tensorflow/mnist_softmax/saved",
+ byte[] converted = VariableConverter.importVariable("src/test/models/tensorflow/mnist_softmax/saved",
"Variable_1",
"tensor(d0[10],d1[1])");
assertEquals("{\"cells\":[{\"address\":{\"d0\":\"0\",\"d1\":\"0\"},\"value\":-0.3546536862850189},{\"address\":{\"d0\":\"1\",\"d1\":\"0\"},\"value\":0.3759574592113495},{\"address\":{\"d0\":\"2\",\"d1\":\"0\"},\"value\":0.06054411828517914},{\"address\":{\"d0\":\"3\",\"d1\":\"0\"},\"value\":-0.251544713973999},{\"address\":{\"d0\":\"4\",\"d1\":\"0\"},\"value\":0.017951013520359993},{\"address\":{\"d0\":\"5\",\"d1\":\"0\"},\"value\":1.2899067401885986},{\"address\":{\"d0\":\"6\",\"d1\":\"0\"},\"value\":-0.10389615595340729},{\"address\":{\"d0\":\"7\",\"d1\":\"0\"},\"value\":0.6367976665496826},{\"address\":{\"d0\":\"8\",\"d1\":\"0\"},\"value\":-1.4136744737625122},{\"address\":{\"d0\":\"9\",\"d1\":\"0\"},\"value\":-0.2573896050453186}]}",
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java
new file mode 100644
index 00000000000..48c7f5bee19
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java
@@ -0,0 +1,28 @@
+package ai.vespa.rankingexpression.importer.xgboost;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * @author bratseth
+ */
+public class XGBoostImportTestCase {
+
+ @Test
+ public void testXGBoost() {
+ ImportedModel model = new XGBoostImporter().importModel("test", "src/test/models/xgboost/xgboost.2.2.json");
+ assertTrue("All inputs are scalar", model.inputs().isEmpty());
+ assertEquals(1, model.expressions().size());
+ System.out.println(model.expressions().keySet());
+ RankingExpression expression = model.expressions().get("test");
+ assertNotNull(expression);
+ assertEquals("if (f29 < -0.1234567, if (f56 < -0.242398, 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + if (f60 < -0.482947, if (f29 < -4.2387498, 0.784718, -0.96853), -6.23624)",
+ expression.getRoot().toString());
+ }
+
+}
diff --git a/searchlib/src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx b/model-integration/src/test/models/onnx/mnist_softmax/mnist_softmax.onnx
index a86019bf53a..a86019bf53a 100644
--- a/searchlib/src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx
+++ b/model-integration/src/test/models/onnx/mnist_softmax/mnist_softmax.onnx
Binary files differ
diff --git a/searchlib/src/test/files/integration/tensorflow/batch_norm/batch_normalization_mnist.py b/model-integration/src/test/models/tensorflow/batch_norm/batch_normalization_mnist.py
index bc6ea13ebc1..bc6ea13ebc1 100644
--- a/searchlib/src/test/files/integration/tensorflow/batch_norm/batch_normalization_mnist.py
+++ b/model-integration/src/test/models/tensorflow/batch_norm/batch_normalization_mnist.py
diff --git a/searchlib/src/test/files/integration/tensorflow/batch_norm/saved/saved_model.pbtxt b/model-integration/src/test/models/tensorflow/batch_norm/saved/saved_model.pbtxt
index f3ce68a1cbd..f3ce68a1cbd 100644
--- a/searchlib/src/test/files/integration/tensorflow/batch_norm/saved/saved_model.pbtxt
+++ b/model-integration/src/test/models/tensorflow/batch_norm/saved/saved_model.pbtxt
diff --git a/searchlib/src/test/files/integration/tensorflow/batch_norm/saved/variables/variables.data-00000-of-00001 b/model-integration/src/test/models/tensorflow/batch_norm/saved/variables/variables.data-00000-of-00001
index 875e8361e10..875e8361e10 100644
--- a/searchlib/src/test/files/integration/tensorflow/batch_norm/saved/variables/variables.data-00000-of-00001
+++ b/model-integration/src/test/models/tensorflow/batch_norm/saved/variables/variables.data-00000-of-00001
Binary files differ
diff --git a/searchlib/src/test/files/integration/tensorflow/batch_norm/saved/variables/variables.index b/model-integration/src/test/models/tensorflow/batch_norm/saved/variables/variables.index
index 46c7b258cf5..46c7b258cf5 100644
--- a/searchlib/src/test/files/integration/tensorflow/batch_norm/saved/variables/variables.index
+++ b/model-integration/src/test/models/tensorflow/batch_norm/saved/variables/variables.index
Binary files differ
diff --git a/searchlib/src/test/files/integration/tensorflow/blog/saved/saved_model.pbtxt b/model-integration/src/test/models/tensorflow/blog/saved/saved_model.pbtxt
index a669e69b709..a669e69b709 100644
--- a/searchlib/src/test/files/integration/tensorflow/blog/saved/saved_model.pbtxt
+++ b/model-integration/src/test/models/tensorflow/blog/saved/saved_model.pbtxt
diff --git a/searchlib/src/test/files/integration/tensorflow/blog/saved/variables/variables.data-00000-of-00001 b/model-integration/src/test/models/tensorflow/blog/saved/variables/variables.data-00000-of-00001
index 1efd102aef9..1efd102aef9 100644
--- a/searchlib/src/test/files/integration/tensorflow/blog/saved/variables/variables.data-00000-of-00001
+++ b/model-integration/src/test/models/tensorflow/blog/saved/variables/variables.data-00000-of-00001
Binary files differ
diff --git a/searchlib/src/test/files/integration/tensorflow/blog/saved/variables/variables.index b/model-integration/src/test/models/tensorflow/blog/saved/variables/variables.index
index 56c60dbe529..56c60dbe529 100644
--- a/searchlib/src/test/files/integration/tensorflow/blog/saved/variables/variables.index
+++ b/model-integration/src/test/models/tensorflow/blog/saved/variables/variables.index
Binary files differ
diff --git a/searchlib/src/test/files/integration/tensorflow/dropout/dropout.py b/model-integration/src/test/models/tensorflow/dropout/dropout.py
index 42c15cd2812..42c15cd2812 100644
--- a/searchlib/src/test/files/integration/tensorflow/dropout/dropout.py
+++ b/model-integration/src/test/models/tensorflow/dropout/dropout.py
diff --git a/searchlib/src/test/files/integration/tensorflow/dropout/saved/saved_model.pbtxt b/model-integration/src/test/models/tensorflow/dropout/saved/saved_model.pbtxt
index ad431f0460d..ad431f0460d 100644
--- a/searchlib/src/test/files/integration/tensorflow/dropout/saved/saved_model.pbtxt
+++ b/model-integration/src/test/models/tensorflow/dropout/saved/saved_model.pbtxt
diff --git a/searchlib/src/test/files/integration/tensorflow/dropout/saved/variables/variables.data-00000-of-00001 b/model-integration/src/test/models/tensorflow/dropout/saved/variables/variables.data-00000-of-00001
index 000c9b3a7b5..000c9b3a7b5 100644
--- a/searchlib/src/test/files/integration/tensorflow/dropout/saved/variables/variables.data-00000-of-00001
+++ b/model-integration/src/test/models/tensorflow/dropout/saved/variables/variables.data-00000-of-00001
Binary files differ
diff --git a/searchlib/src/test/files/integration/tensorflow/dropout/saved/variables/variables.index b/model-integration/src/test/models/tensorflow/dropout/saved/variables/variables.index
index 9492ef4bde2..9492ef4bde2 100644
--- a/searchlib/src/test/files/integration/tensorflow/dropout/saved/variables/variables.index
+++ b/model-integration/src/test/models/tensorflow/dropout/saved/variables/variables.index
Binary files differ
diff --git a/searchlib/src/test/files/integration/tensorflow/mnist/saved/saved_model.pbtxt b/model-integration/src/test/models/tensorflow/mnist/saved/saved_model.pbtxt
index eb926836576..eb926836576 100644
--- a/searchlib/src/test/files/integration/tensorflow/mnist/saved/saved_model.pbtxt
+++ b/model-integration/src/test/models/tensorflow/mnist/saved/saved_model.pbtxt
diff --git a/searchlib/src/test/files/integration/tensorflow/mnist/saved/variables/variables.data-00000-of-00001 b/model-integration/src/test/models/tensorflow/mnist/saved/variables/variables.data-00000-of-00001
index a7ca01888c7..a7ca01888c7 100644
--- a/searchlib/src/test/files/integration/tensorflow/mnist/saved/variables/variables.data-00000-of-00001
+++ b/model-integration/src/test/models/tensorflow/mnist/saved/variables/variables.data-00000-of-00001
Binary files differ
diff --git a/searchlib/src/test/files/integration/tensorflow/mnist/saved/variables/variables.index b/model-integration/src/test/models/tensorflow/mnist/saved/variables/variables.index
index 7989c109a3a..7989c109a3a 100644
--- a/searchlib/src/test/files/integration/tensorflow/mnist/saved/variables/variables.index
+++ b/model-integration/src/test/models/tensorflow/mnist/saved/variables/variables.index
Binary files differ
diff --git a/searchlib/src/test/files/integration/tensorflow/mnist/simple_mnist.py b/model-integration/src/test/models/tensorflow/mnist/simple_mnist.py
index 86a17e81f8f..86a17e81f8f 100644
--- a/searchlib/src/test/files/integration/tensorflow/mnist/simple_mnist.py
+++ b/model-integration/src/test/models/tensorflow/mnist/simple_mnist.py
diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py b/model-integration/src/test/models/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py
index 07a9fa4a213..07a9fa4a213 100644
--- a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py
+++ b/model-integration/src/test/models/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py
diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt b/model-integration/src/test/models/tensorflow/mnist_softmax/saved/saved_model.pbtxt
index 8100dfd594d..8100dfd594d 100644
--- a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt
+++ b/model-integration/src/test/models/tensorflow/mnist_softmax/saved/saved_model.pbtxt
diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001 b/model-integration/src/test/models/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001
index 8474aa0a04c..8474aa0a04c 100644
--- a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001
+++ b/model-integration/src/test/models/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001
Binary files differ
diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.index b/model-integration/src/test/models/tensorflow/mnist_softmax/saved/variables/variables.index
index cfcdac20409..cfcdac20409 100644
--- a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.index
+++ b/model-integration/src/test/models/tensorflow/mnist_softmax/saved/variables/variables.index
Binary files differ
diff --git a/model-integration/src/test/models/xgboost/xgboost.2.2.json b/model-integration/src/test/models/xgboost/xgboost.2.2.json
new file mode 100644
index 00000000000..f8949b47e52
--- /dev/null
+++ b/model-integration/src/test/models/xgboost/xgboost.2.2.json
@@ -0,0 +1,19 @@
+[
+ { "nodeid": 0, "depth": 0, "split": "f29", "split_condition": -0.1234567, "yes": 1, "no": 2, "missing": 1, "children": [
+ { "nodeid": 1, "depth": 1, "split": "f56", "split_condition": -0.242398, "yes": 3, "no": 4, "missing": 3, "children": [
+ { "nodeid": 3, "leaf": 1.71218 },
+ { "nodeid": 4, "leaf": -1.70044 }
+ ]},
+ { "nodeid": 2, "depth": 1, "split": "f109", "split_condition": 0.8723473, "yes": 5, "no": 6, "missing": 5, "children": [
+ { "nodeid": 5, "leaf": -1.94071 },
+ { "nodeid": 6, "leaf": 1.85965 }
+ ]}
+ ]},
+ { "nodeid": 0, "depth": 0, "split": "f60", "split_condition": -0.482947, "yes": 1, "no": 2, "missing": 1, "children": [
+ { "nodeid": 1, "depth": 1, "split": "f29", "split_condition": -4.2387498, "yes": 3, "no": 4, "missing": 3, "children": [
+ { "nodeid": 3, "leaf": 0.784718 },
+ { "nodeid": 4, "leaf": -0.96853 }
+ ]},
+ { "nodeid": 2, "leaf": -6.23624 }
+ ]}
+] \ No newline at end of file
diff --git a/pom.xml b/pom.xml
index f749adbdabf..386aa5bc5a2 100644
--- a/pom.xml
+++ b/pom.xml
@@ -94,6 +94,7 @@
<module>messagebus</module>
<module>metrics</module>
<module>model-evaluation</module>
+ <module>model-integration</module>
<module>node-repository</module>
<module>node-admin</module>
<module>node-maintainer</module>
diff --git a/searchlib/pom.xml b/searchlib/pom.xml
index f1be4e96269..e0ce822e593 100644
--- a/searchlib/pom.xml
+++ b/searchlib/pom.xml
@@ -35,18 +35,6 @@
<version>${project.version}</version>
</dependency>
<dependency>
- <groupId>com.google.protobuf</groupId>
- <artifactId>protobuf-java</artifactId>
- </dependency>
- <dependency>
- <groupId>org.tensorflow</groupId>
- <artifactId>proto</artifactId>
- </dependency>
- <dependency>
- <groupId>org.tensorflow</groupId>
- <artifactId>tensorflow</artifactId>
- </dependency>
- <dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-core</artifactId>
<scope>provided</scope>
@@ -117,26 +105,6 @@
</execution>
</executions>
</plugin>
- <plugin>
- <groupId>com.github.os72</groupId>
- <artifactId>protoc-jar-maven-plugin</artifactId>
- <version>3.5.1.1</version>
- <executions>
- <execution>
- <phase>generate-sources</phase>
- <goals>
- <goal>run</goal>
- </goals>
- <configuration>
- <addSources>main</addSources>
- <outputDirectory>${project.build.directory}/generated-sources/protobuf</outputDirectory>
- <inputDirectories>
- <include>src/main/protobuf</include>
- </inputDirectories>
- </configuration>
- </execution>
- </executions>
- </plugin>
</plugins>
</build>
</project>
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/ImportedModel.java
index 59ec66b7209..854a5202916 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/ImportedModel.java
@@ -1,7 +1,6 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer;
-import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.yahoo.collections.Pair;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
@@ -10,13 +9,11 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.util.ArrayList;
-import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
-import java.util.Objects;
import java.util.Optional;
import java.util.regex.Pattern;
@@ -64,8 +61,7 @@ public class ImportedModel {
/**
* Returns an immutable map of the small constants of this.
- * These should have sizes up to a few kb at most, and correspond to constant
- * values given in the TensorFlow or ONNX source.
+ * These should have sizes up to a few kb at most, and correspond to constant values given in the source model.
*/
public Map<String, Tensor> smallConstants() { return Collections.unmodifiableMap(smallConstants); }
@@ -93,18 +89,18 @@ public class ImportedModel {
public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); }
/** Returns the given signature. If it does not already exist it is added to this. */
- Signature signature(String name) {
+ public Signature signature(String name) {
return signatures.computeIfAbsent(name, Signature::new);
}
/** Convenience method for returning a default signature */
- Signature defaultSignature() { return signature(defaultSignatureName); }
+ public Signature defaultSignature() { return signature(defaultSignatureName); }
- void input(String name, TensorType argumentType) { inputs.put(name, argumentType); }
- void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); }
- void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); }
- void expression(String name, RankingExpression expression) { expressions.put(name, expression); }
- void function(String name, RankingExpression expression) { functions.put(name, expression); }
+ public void input(String name, TensorType argumentType) { inputs.put(name, argumentType); }
+ public void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); }
+ public void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); }
+ public void expression(String name, RankingExpression expression) { expressions.put(name, expression); }
+ public void function(String name, RankingExpression expression) { functions.put(name, expression); }
/**
* Returns all the output expressions of this indexed by name. The names consist of one or two parts
@@ -116,9 +112,9 @@ public class ImportedModel {
List<Pair<String, ExpressionFunction>> expressions = new ArrayList<>();
for (Map.Entry<String, Signature> signatureEntry : signatures().entrySet()) {
for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet())
- expressions.add(new Pair<>(signatureEntry.getKey() + "." + outputEntry.getKey(),
+ expressions.add(new Pair<>(signatureEntry.getKey() + "" + outputEntry.getKey(),
signatureEntry.getValue().outputExpression(outputEntry.getKey())
- .withName(signatureEntry.getKey() + "." + outputEntry.getKey())));
+ .withName(signatureEntry.getKey() + "" + outputEntry.getKey())));
if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs
expressions.add(new Pair<>(signatureEntry.getKey(),
new ExpressionFunction(signatureEntry.getKey(),
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/ImportedModels.java
index 40d1ca8030a..55f1eef741c 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/ImportedModels.java
@@ -1,7 +1,6 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer;
-import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.yahoo.path.Path;
@@ -24,27 +23,26 @@ public class ImportedModels {
/** All imported models, indexed by their names */
private final ImmutableMap<String, ImportedModel> importedModels;
- private static final ImmutableList<ModelImporter> importers =
- ImmutableList.of(new TensorFlowImporter(), new OnnxImporter(), new XGBoostImporter());
-
/** Create a null imported models */
public ImportedModels() {
importedModels = ImmutableMap.of();
}
- public ImportedModels(File modelsDirectory) {
+ public ImportedModels(File modelsDirectory, Collection<ModelImporter> importers) {
Map<String, ImportedModel> models = new HashMap<>();
// Find all subdirectories recursively which contains a model we can read
- importRecursively(modelsDirectory, models);
+ importRecursively(modelsDirectory, models, importers);
importedModels = ImmutableMap.copyOf(models);
}
- private static void importRecursively(File dir, Map<String, ImportedModel> models) {
+ private static void importRecursively(File dir,
+ Map<String, ImportedModel> models,
+ Collection<ModelImporter> importers) {
if ( ! dir.isDirectory()) return;
Arrays.stream(dir.listFiles()).sorted().forEach(child -> {
- Optional<ModelImporter> importer = findImporterOf(child);
+ Optional<ModelImporter> importer = findImporterOf(child, importers);
if (importer.isPresent()) {
String name = toName(child);
ImportedModel existing = models.get(name);
@@ -54,12 +52,12 @@ public class ImportedModels {
models.put(name, importer.get().importModel(name, child));
}
else {
- importRecursively(child, models);
+ importRecursively(child, models, importers);
}
});
}
- private static Optional<ModelImporter> findImporterOf(File path) {
+ private static Optional<ModelImporter> findImporterOf(File path, Collection<ModelImporter> importers) {
return importers.stream().filter(item -> item.canImport(path.toString())).findFirst();
}
@@ -93,7 +91,7 @@ public class ImportedModels {
}
private static Path stripFileEnding(Path path) {
- int dotIndex = path.last().lastIndexOf(".");
+ int dotIndex = path.last().lastIndexOf("");
if (dotIndex <= 0) return path;
return path.withLast(path.last().substring(0, dotIndex));
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/ModelImporter.java
index 481b7f9397a..1b6494e8ce8 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/ModelImporter.java
@@ -1,11 +1,9 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Constant;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
@@ -47,7 +45,7 @@ public abstract class ModelImporter {
* Takes an IntermediateGraph and converts it to a ImportedModel containing
* the actual Vespa ranking expressions.
*/
- static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph, String modelSource) {
+ public static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph, String modelSource) {
ImportedModel model = new ImportedModel(graph.name(), modelSource);
graph.optimize();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java
index 7fc2aae87d1..ab6a80a193a 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java
@@ -1,8 +1,8 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.VariableTensor;
import com.yahoo.tensor.functions.Rename;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java
index 1b8c62fe0e9..1f479cd2e0b 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java
@@ -1,8 +1,8 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java
index 0eff8e8bc08..aab50f422be 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java
@@ -180,7 +180,7 @@ public abstract class IntermediateOperation {
/**
* An interface mapping operation attributes to Vespa Values.
- * Adapter for differences in ONNX/TensorFlow.
+ * Adapter for differences in different model types.
*/
public interface AttributeMap {
Optional<Value> get(String key);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java
index 8413ed74118..2d401469e40 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java
@@ -1,8 +1,8 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.TensorFunction;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java
index 95a77c07590..6ce9abf2ec9 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java
@@ -3,8 +3,8 @@ package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java
index e91c2305f7d..ff87412396d 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java
@@ -2,8 +2,8 @@
package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode;
@@ -24,8 +24,6 @@ import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
-import static com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType.tensorSize;
-
public class Reshape extends IntermediateOperation {
public Reshape(String modelName, String nodeName, List<IntermediateOperation> inputs) {
@@ -52,7 +50,7 @@ public class Reshape extends IntermediateOperation {
int size = cell.getValue().intValue();
if (size < 0) {
size = -1 * (int)shape.reduce(Reduce.Aggregator.prod).asDouble() /
- tensorSize(inputType.type()).intValue();
+ OrderedTensorType.tensorSize(inputType.type()).intValue();
}
outputTypeBuilder.add(TensorType.Dimension.indexed(
String.format("%s_%d", vespaName(), dimensionIndex), size));
@@ -82,7 +80,7 @@ public class Reshape extends IntermediateOperation {
}
public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) {
- if (!tensorSize(inputType).equals(tensorSize(outputType))) {
+ if (!OrderedTensorType.tensorSize(inputType).equals(OrderedTensorType.tensorSize(outputType))) {
throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping");
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/package-info.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/package-info.java
index 1530754cc43..bb55ed768a6 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/package-info.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/package-info.java
@@ -1,8 +1,8 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
/**
- * ONNX integration
+ * Model integration
*/
@ExportPackage
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer;
import com.yahoo.osgi.annotation.ExportPackage;
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/DimensionRenamerTest.java
index b3dafff621c..4bd28a74d6f 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/DimensionRenamerTest.java
@@ -1,5 +1,5 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
import org.junit.Test;
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/OrderedTensorTypeTestCase.java
index 55e1d234782..5118081637e 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/OrderedTensorTypeTestCase.java
@@ -1,5 +1,5 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import org.junit.Test;
diff --git a/standalone-container/vespa-standalone-container.spec b/standalone-container/vespa-standalone-container.spec
index f051142c516..6143df6a446 100644
--- a/standalone-container/vespa-standalone-container.spec
+++ b/standalone-container/vespa-standalone-container.spec
@@ -54,6 +54,7 @@ declare -a modules=(
jdisc_core
jdisc_http_service
model-evaluation
+ model-integration
security-utils
simplemetrics
standalone-container