summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--CMakeLists.txt1
-rw-r--r--config-model/pom.xml14
-rw-r--r--config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java55
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java2
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java2
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java2
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java2
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java2
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java7
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java18
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java3
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java3
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java3
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java3
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java5
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java6
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java13
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java2
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java15
-rw-r--r--configserver/src/main/resources/configserver-app/services.xml1
-rw-r--r--container-dev/pom.xml24
-rw-r--r--container-disc/pom.xml1
-rw-r--r--container/pom.xml12
-rw-r--r--document/src/main/java/com/yahoo/document/Document.java1
-rw-r--r--model-integration/pom.xml100
-rw-r--r--model-integration/src/main/config/model-integration.xml10
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java)8
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java)5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java)6
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java)10
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/package-info.java5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/AttributeConverter.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java)9
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java)10
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TensorConverter.java)14
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java)7
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java)10
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/VariableConverter.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverter.java)8
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/package-info.java4
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/XGBoostImporter.java)5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostParser.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostParser.java)10
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostTree.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostTree.java)2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/package-info.java4
-rw-r--r--model-integration/src/main/protobuf/onnx.proto (renamed from searchlib/src/main/protobuf/onnx.proto)0
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java)10
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BatchNormImportTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java)6
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BlogEvaluationBenchmark.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BlogEvaluationBenchmark.java)10
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java)17
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/MnistImportTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java)11
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java)19
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java)4
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/VariableConverterTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverterTestCase.java)4
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java28
-rw-r--r--model-integration/src/test/models/onnx/mnist_softmax/mnist_softmax.onnx (renamed from searchlib/src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx)bin31758 -> 31758 bytes
-rw-r--r--model-integration/src/test/models/tensorflow/batch_norm/batch_normalization_mnist.py (renamed from searchlib/src/test/files/integration/tensorflow/batch_norm/batch_normalization_mnist.py)0
-rw-r--r--model-integration/src/test/models/tensorflow/batch_norm/saved/saved_model.pbtxt (renamed from searchlib/src/test/files/integration/tensorflow/batch_norm/saved/saved_model.pbtxt)0
-rw-r--r--model-integration/src/test/models/tensorflow/batch_norm/saved/variables/variables.data-00000-of-00001 (renamed from searchlib/src/test/files/integration/tensorflow/batch_norm/saved/variables/variables.data-00000-of-00001)bin1073000 -> 1073000 bytes
-rw-r--r--model-integration/src/test/models/tensorflow/batch_norm/saved/variables/variables.index (renamed from searchlib/src/test/files/integration/tensorflow/batch_norm/saved/variables/variables.index)bin686 -> 686 bytes
-rw-r--r--model-integration/src/test/models/tensorflow/blog/saved/saved_model.pbtxt (renamed from searchlib/src/test/files/integration/tensorflow/blog/saved/saved_model.pbtxt)0
-rw-r--r--model-integration/src/test/models/tensorflow/blog/saved/variables/variables.data-00000-of-00001 (renamed from searchlib/src/test/files/integration/tensorflow/blog/saved/variables/variables.data-00000-of-00001)bin1579020 -> 1579020 bytes
-rw-r--r--model-integration/src/test/models/tensorflow/blog/saved/variables/variables.index (renamed from searchlib/src/test/files/integration/tensorflow/blog/saved/variables/variables.index)bin520 -> 520 bytes
-rw-r--r--model-integration/src/test/models/tensorflow/dropout/dropout.py (renamed from searchlib/src/test/files/integration/tensorflow/dropout/dropout.py)0
-rw-r--r--model-integration/src/test/models/tensorflow/dropout/saved/saved_model.pbtxt (renamed from searchlib/src/test/files/integration/tensorflow/dropout/saved/saved_model.pbtxt)0
-rw-r--r--model-integration/src/test/models/tensorflow/dropout/saved/variables/variables.data-00000-of-00001 (renamed from searchlib/src/test/files/integration/tensorflow/dropout/saved/variables/variables.data-00000-of-00001)bin31400 -> 31400 bytes
-rw-r--r--model-integration/src/test/models/tensorflow/dropout/saved/variables/variables.index (renamed from searchlib/src/test/files/integration/tensorflow/dropout/saved/variables/variables.index)bin165 -> 165 bytes
-rw-r--r--model-integration/src/test/models/tensorflow/mnist/saved/saved_model.pbtxt (renamed from searchlib/src/test/files/integration/tensorflow/mnist/saved/saved_model.pbtxt)0
-rw-r--r--model-integration/src/test/models/tensorflow/mnist/saved/variables/variables.data-00000-of-00001 (renamed from searchlib/src/test/files/integration/tensorflow/mnist/saved/variables/variables.data-00000-of-00001)bin1066440 -> 1066440 bytes
-rw-r--r--model-integration/src/test/models/tensorflow/mnist/saved/variables/variables.index (renamed from searchlib/src/test/files/integration/tensorflow/mnist/saved/variables/variables.index)bin308 -> 308 bytes
-rw-r--r--model-integration/src/test/models/tensorflow/mnist/simple_mnist.py (renamed from searchlib/src/test/files/integration/tensorflow/mnist/simple_mnist.py)0
-rw-r--r--model-integration/src/test/models/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py (renamed from searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py)0
-rw-r--r--model-integration/src/test/models/tensorflow/mnist_softmax/saved/saved_model.pbtxt (renamed from searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt)0
-rw-r--r--model-integration/src/test/models/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001 (renamed from searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001)bin31400 -> 31400 bytes
-rw-r--r--model-integration/src/test/models/tensorflow/mnist_softmax/saved/variables/variables.index (renamed from searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.index)bin159 -> 159 bytes
-rw-r--r--model-integration/src/test/models/xgboost/xgboost.2.2.json19
-rw-r--r--pom.xml1
-rw-r--r--searchlib/pom.xml32
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/ImportedModel.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java)26
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/ImportedModels.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java)22
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/ModelImporter.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java)6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java8
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/package-info.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/package-info.java)4
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/DimensionRenamerTest.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java)2
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/OrderedTensorTypeTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java)2
-rw-r--r--standalone-container/vespa-standalone-container.spec1
97 files changed, 417 insertions, 267 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index e97f601fb74..26e9fa8a52a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -81,6 +81,7 @@ add_subdirectory(messagebus)
add_subdirectory(messagebus_test)
add_subdirectory(metrics)
add_subdirectory(model-evaluation)
+add_subdirectory(model-integration)
add_subdirectory(node-repository)
add_subdirectory(orchestrator)
add_subdirectory(persistence)
diff --git a/config-model/pom.xml b/config-model/pom.xml
index a65a6a836ed..b377760249a 100644
--- a/config-model/pom.xml
+++ b/config-model/pom.xml
@@ -293,16 +293,10 @@
<scope>test</scope>
</dependency>
<dependency>
- <groupId>com.google.protobuf</groupId>
- <artifactId>protobuf-java</artifactId>
- </dependency>
- <dependency>
- <groupId>org.tensorflow</groupId>
- <artifactId>proto</artifactId>
- </dependency>
- <dependency>
- <groupId>org.tensorflow</groupId>
- <artifactId>tensorflow</artifactId>
+ <groupId>com.yahoo.vespa</groupId>
+ <artifactId>model-integration</artifactId>
+ <version>${project.version}</version>
+ <scope>test</scope>
</dependency>
</dependencies>
diff --git a/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java b/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java
index 574c25a2f84..d784722f77a 100644
--- a/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java
+++ b/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java
@@ -21,8 +21,9 @@ import com.yahoo.config.provision.Zone;
import com.yahoo.io.reader.NamedReader;
import com.yahoo.searchdefinition.RankProfileRegistry;
import com.yahoo.searchdefinition.SearchBuilder;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import com.yahoo.searchdefinition.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ModelImporter;
import com.yahoo.vespa.config.ConfigDefinition;
import com.yahoo.vespa.config.ConfigDefinitionBuilder;
import com.yahoo.vespa.config.ConfigDefinitionKey;
@@ -39,6 +40,7 @@ import java.io.IOException;
import java.io.Reader;
import java.time.Instant;
import java.util.Collection;
+import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
@@ -80,11 +82,23 @@ public class DeployState implements ConfigDefinitionStore {
return new Builder().applicationPackage(applicationPackage).build();
}
- private DeployState(ApplicationPackage applicationPackage, SearchDocumentModel searchDocumentModel, RankProfileRegistry rankProfileRegistry,
- FileRegistry fileRegistry, DeployLogger deployLogger, Optional<HostProvisioner> hostProvisioner, DeployProperties properties,
- Optional<ApplicationPackage> permanentApplicationPackage, Optional<ConfigDefinitionRepo> configDefinitionRepo,
- java.util.Optional<Model> previousModel, Set<Rotation> rotations, Zone zone, QueryProfiles queryProfiles,
- SemanticRules semanticRules, Instant now, Version wantedNodeVespaVersion) {
+ private DeployState(ApplicationPackage applicationPackage,
+ SearchDocumentModel searchDocumentModel,
+ RankProfileRegistry rankProfileRegistry,
+ FileRegistry fileRegistry,
+ DeployLogger deployLogger,
+ Optional<HostProvisioner> hostProvisioner,
+ DeployProperties properties,
+ Optional<ApplicationPackage> permanentApplicationPackage,
+ Optional<ConfigDefinitionRepo> configDefinitionRepo,
+ java.util.Optional<Model> previousModel,
+ Set<Rotation> rotations,
+ Collection<ModelImporter> modelImporters,
+ Zone zone,
+ QueryProfiles queryProfiles,
+ SemanticRules semanticRules,
+ Instant now,
+ Version wantedNodeVespaVersion) {
this.logger = deployLogger;
this.fileRegistry = fileRegistry;
this.rankProfileRegistry = rankProfileRegistry;
@@ -100,7 +114,8 @@ public class DeployState implements ConfigDefinitionStore {
this.zone = zone;
this.queryProfiles = queryProfiles; // TODO: Remove this by seeing how pagetemplates are propagated
this.semanticRules = semanticRules; // TODO: Remove this by seeing how pagetemplates are propagated
- this.importedModels = new ImportedModels(applicationPackage.getFileReference(ApplicationPackage.MODELS_DIR));
+ this.importedModels = new ImportedModels(applicationPackage.getFileReference(ApplicationPackage.MODELS_DIR),
+ modelImporters);
this.validationOverrides = applicationPackage.getValidationOverrides().map(ValidationOverrides::fromXml).orElse(ValidationOverrides.empty);
this.wantedNodeVespaVersion = wantedNodeVespaVersion;
@@ -232,6 +247,7 @@ public class DeployState implements ConfigDefinitionStore {
private Optional<ConfigDefinitionRepo> configDefinitionRepo = Optional.empty();
private Optional<Model> previousModel = Optional.empty();
private Set<Rotation> rotations = new HashSet<>();
+ private Collection<ModelImporter> modelImporters = Collections.emptyList();
private Zone zone = Zone.defaultZone();
private Instant now = Instant.now();
private Version wantedNodeVespaVersion = Vtag.currentVersion;
@@ -281,6 +297,11 @@ public class DeployState implements ConfigDefinitionStore {
return this;
}
+ public Builder modelImporters(Collection<ModelImporter> modelImporters) {
+ this.modelImporters = modelImporters;
+ return this;
+ }
+
public Builder zone(Zone zone) {
this.zone = zone;
return this;
@@ -305,9 +326,23 @@ public class DeployState implements ConfigDefinitionStore {
QueryProfiles queryProfiles = new QueryProfilesBuilder().build(applicationPackage);
SemanticRules semanticRules = new SemanticRuleBuilder().build(applicationPackage);
SearchDocumentModel searchDocumentModel = createSearchDocumentModel(rankProfileRegistry, logger, queryProfiles, validationParameters);
- return new DeployState(applicationPackage, searchDocumentModel, rankProfileRegistry, fileRegistry, logger, hostProvisioner,
- properties, permanentApplicationPackage, configDefinitionRepo, previousModel, rotations,
- zone, queryProfiles, semanticRules, now, wantedNodeVespaVersion);
+ return new DeployState(applicationPackage,
+ searchDocumentModel,
+ rankProfileRegistry,
+ fileRegistry,
+ logger,
+ hostProvisioner,
+ properties,
+ permanentApplicationPackage,
+ configDefinitionRepo,
+ previousModel,
+ rotations,
+ modelImporters,
+ zone,
+ queryProfiles,
+ semanticRules,
+ now,
+ wantedNodeVespaVersion);
}
private SearchDocumentModel createSearchDocumentModel(RankProfileRegistry rankProfileRegistry,
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
index 78f61d7192d..a0cbc4271f3 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
@@ -16,7 +16,7 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.TypeContext;
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java
index 0c5c7733dda..eb221a80638 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java
@@ -12,7 +12,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.RankProfileRegistry;
import com.yahoo.searchdefinition.Search;
import com.yahoo.searchdefinition.derived.validation.Validation;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import java.io.IOException;
import java.io.Writer;
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
index fcbfb47c597..c955cd9956b 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
@@ -5,7 +5,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.RankProfileRegistry;
import com.yahoo.searchdefinition.RankingConstant;
import com.yahoo.searchdefinition.RankingConstants;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.Search;
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
index a1b0e72051b..fd468c8eca7 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
@@ -9,7 +9,7 @@ import com.yahoo.searchdefinition.document.RankType;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java
index 2fe2dacf2ce..f52eedc8d08 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java
@@ -4,7 +4,7 @@ package com.yahoo.searchdefinition.expressiontransforms;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import com.yahoo.searchlib.rankingexpression.transform.TransformContext;
import java.util.HashMap;
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
index fef198f7939..b4de5e925f6 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
@@ -33,8 +33,8 @@ import com.yahoo.searchdefinition.derived.RankProfileList;
import com.yahoo.searchdefinition.processing.Processing;
import com.yahoo.vespa.model.container.search.QueryProfiles;
import com.yahoo.vespa.model.ml.ConvertedModel;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import com.yahoo.vespa.config.ConfigDefinitionKey;
import com.yahoo.vespa.config.ConfigKey;
import com.yahoo.vespa.config.ConfigPayload;
@@ -150,7 +150,8 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
this(configModelRegistry, deployState, true, null);
}
- private VespaModel(ConfigModelRegistry configModelRegistry, DeployState deployState, boolean complete, FileDistributor fileDistributor) throws IOException, SAXException {
+ private VespaModel(ConfigModelRegistry configModelRegistry, DeployState deployState, boolean complete, FileDistributor fileDistributor)
+ throws IOException, SAXException {
super("vespamodel");
this.validationOverrides = deployState.validationOverrides();
configModelRegistry = new VespaConfigModelRegistry(configModelRegistry);
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java
index 2af9b297e9e..ff64b6753f9 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java
@@ -21,6 +21,7 @@ import com.yahoo.config.model.deploy.DeployProperties;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.config.provision.Version;
import com.yahoo.config.provision.Zone;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ModelImporter;
import com.yahoo.vespa.config.VespaVersion;
import com.yahoo.vespa.model.application.validation.Validation;
@@ -29,6 +30,8 @@ import org.xml.sax.SAXException;
import java.io.IOException;
import java.time.Clock;
import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
import java.util.List;
import java.util.logging.Logger;
@@ -41,19 +44,17 @@ public class VespaModelFactory implements ModelFactory {
private static final Logger log = Logger.getLogger(VespaModelFactory.class.getName());
private final ConfigModelRegistry configModelRegistry;
+ private final Collection<ModelImporter> modelImporters;
private final Zone zone;
private final Clock clock;
private final Version version;
/** Creates a factory for vespa models for this version of the source */
@Inject
- public VespaModelFactory(ComponentRegistry<ConfigModelPlugin> pluginRegistry, Zone zone) {
- this(Version.fromIntValues(VespaVersion.major, VespaVersion.minor, VespaVersion.micro), pluginRegistry, zone);
- }
-
- /** Creates a factory for vespa models of a particular version */
- public VespaModelFactory(Version version, ComponentRegistry<ConfigModelPlugin> pluginRegistry, Zone zone) {
- this.version = version;
+ public VespaModelFactory(ComponentRegistry<ConfigModelPlugin> pluginRegistry,
+ ComponentRegistry<ModelImporter> modelImporters,
+ Zone zone) {
+ this.version = Version.fromIntValues(VespaVersion.major, VespaVersion.minor, VespaVersion.micro);
List<ConfigModelBuilder> modelBuilders = new ArrayList<>();
for (ConfigModelPlugin plugin : pluginRegistry.allComponents()) {
if (plugin instanceof ConfigModelBuilder) {
@@ -61,6 +62,7 @@ public class VespaModelFactory implements ModelFactory {
}
}
this.configModelRegistry = new MapConfigModelRegistry(modelBuilders);
+ this.modelImporters = modelImporters.allComponents();
this.zone = zone;
this.clock = Clock.systemUTC();
}
@@ -79,6 +81,7 @@ public class VespaModelFactory implements ModelFactory {
} else {
this.configModelRegistry = configModelRegistry;
}
+ this.modelImporters = Collections.emptyList();
this.zone = Zone.defaultZone();
this.clock = clock;
}
@@ -137,6 +140,7 @@ public class VespaModelFactory implements ModelFactory {
.properties(createDeployProperties(modelContext.properties()))
.modelHostProvisioner(createHostProvisioner(modelContext))
.rotations(modelContext.properties().rotations())
+ .modelImporters(modelImporters)
.zone(zone)
.now(clock.instant())
.wantedNodeVespaVersion(modelContext.wantedNodeVespaVersion());
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
index 30586b1e677..18f3cd6e088 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
@@ -19,7 +19,7 @@ import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
@@ -41,7 +41,6 @@ import com.yahoo.tensor.serialization.TypedBinaryFormat;
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
-import java.io.Reader;
import java.io.StringReader;
import java.io.UncheckedIOException;
import java.util.ArrayList;
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java
index 07a36832094..ad1676ee7e1 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java
@@ -4,7 +4,7 @@ package com.yahoo.searchdefinition;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.derived.DerivedConfiguration;
import com.yahoo.searchdefinition.parser.ParseException;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import com.yahoo.yolean.Exceptions;
import org.junit.Test;
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java
index 4df3add13c5..14847938f68 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java
@@ -16,7 +16,7 @@ import com.yahoo.searchdefinition.document.RankType;
import com.yahoo.searchdefinition.document.SDDocumentType;
import com.yahoo.searchdefinition.document.SDField;
import com.yahoo.searchdefinition.parser.ParseException;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import org.junit.Test;
import java.util.Iterator;
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java
index 8df3985fd24..4f46d813ba5 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java
@@ -5,7 +5,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RawRankProfile;
import com.yahoo.searchdefinition.parser.ParseException;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java
index 150469cc928..551ef6968d2 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java
@@ -3,7 +3,7 @@ package com.yahoo.searchdefinition;
import com.yahoo.collections.Pair;
import com.yahoo.search.query.profile.QueryProfileRegistry;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import com.yahoo.yolean.Exceptions;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RawRankProfile;
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java
index e507a6c48e4..79adb1cf6bf 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java
@@ -6,8 +6,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RawRankProfile;
import com.yahoo.searchdefinition.parser.ParseException;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
-import com.yahoo.yolean.Exceptions;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import org.junit.Test;
import java.util.Optional;
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java
index 6a1e5b207c6..d97dfb0cbf9 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java
@@ -9,7 +9,7 @@ import com.yahoo.search.query.profile.types.QueryProfileType;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RawRankProfile;
import com.yahoo.searchdefinition.parser.ParseException;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import org.junit.Test;
import java.util.List;
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java
index 5e649c2e551..dedfbce6ae8 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java
@@ -4,9 +4,8 @@ package com.yahoo.searchdefinition;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.derived.DerivedConfiguration;
import com.yahoo.searchdefinition.parser.ParseException;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import com.yahoo.yolean.Exceptions;
-import org.junit.Ignore;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java
index f39896f5779..437f43976b7 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java
@@ -3,12 +3,11 @@ package com.yahoo.searchdefinition.derived;
import com.yahoo.document.DocumenttypesConfig;
import com.yahoo.document.config.DocumentmanagerConfig;
-import com.yahoo.io.IOUtils;
import com.yahoo.searchdefinition.Search;
import com.yahoo.searchdefinition.SearchBuilder;
import com.yahoo.searchdefinition.SearchDefinitionTestCase;
import com.yahoo.searchdefinition.parser.ParseException;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import com.yahoo.vespa.configmodel.producers.DocumentManager;
import com.yahoo.vespa.configmodel.producers.DocumentTypes;
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java
index f4344c9b03c..409b4236a9c 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java
@@ -9,12 +9,9 @@ import com.yahoo.searchdefinition.SearchBuilder;
import com.yahoo.searchdefinition.SearchDefinitionTestCase;
import com.yahoo.searchdefinition.document.SDDocumentType;
import com.yahoo.searchdefinition.document.SDField;
-import com.yahoo.searchdefinition.parser.ParseException;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import org.junit.Test;
-import java.io.IOException;
-
/**
* Tests deriving rank for files from search definitions
*
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java
index 2bae285301c..f2dd16577d1 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java
@@ -11,7 +11,7 @@ import com.yahoo.searchdefinition.SearchBuilder;
import com.yahoo.searchdefinition.document.SDDocumentType;
import com.yahoo.searchdefinition.document.SDField;
import com.yahoo.searchdefinition.processing.Processing;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import com.yahoo.vespa.model.container.search.QueryProfiles;
import org.junit.Test;
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java
index 723cd58a34a..93b366c8d08 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java
@@ -5,7 +5,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.Search;
import com.yahoo.searchdefinition.SearchBuilder;
import com.yahoo.searchdefinition.parser.ParseException;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import org.junit.Test;
import java.io.File;
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java
index 8941b07101d..3a44c123f05 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java
@@ -10,7 +10,7 @@ import com.yahoo.searchdefinition.SearchDefinitionTestCase;
import com.yahoo.searchdefinition.document.SDDocumentType;
import com.yahoo.searchdefinition.document.SDField;
import com.yahoo.searchdefinition.processing.Processing;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import com.yahoo.vespa.model.container.search.QueryProfiles;
import org.junit.Test;
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java
index d38bce04617..1c368ff6f10 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java
@@ -6,13 +6,11 @@ import com.yahoo.searchdefinition.Search;
import com.yahoo.searchdefinition.SearchBuilder;
import com.yahoo.searchdefinition.SearchDefinitionTestCase;
import com.yahoo.searchdefinition.derived.DerivedConfiguration;
-import com.yahoo.searchdefinition.derived.Deriver;
import com.yahoo.searchdefinition.document.SDDocumentType;
import com.yahoo.searchdefinition.parser.ParseException;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
-import org.junit.Ignore;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import org.junit.Test;
-import java.io.File;
+
import java.io.IOException;
import static org.junit.Assert.assertEquals;
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
index cff9abb08ed..7f590d4b230 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition.processing;
+import com.google.common.collect.ImmutableList;
import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.config.model.test.MockApplicationPackage;
import com.yahoo.path.Path;
@@ -10,7 +11,11 @@ import com.yahoo.searchdefinition.RankProfileRegistry;
import com.yahoo.searchdefinition.Search;
import com.yahoo.searchdefinition.SearchBuilder;
import com.yahoo.searchdefinition.parser.ParseException;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ModelImporter;
+import ai.vespa.rankingexpression.importer.onnx.OnnxImporter;
+import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter;
+import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter;
import java.util.HashMap;
import java.util.List;
@@ -26,6 +31,9 @@ import static org.junit.Assert.assertEquals;
*/
class RankProfileSearchFixture {
+ private final ImmutableList<ModelImporter> importers = ImmutableList.of(new TensorFlowImporter(),
+ new OnnxImporter(),
+ new XGBoostImporter());
private RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
private final QueryProfileRegistry queryProfileRegistry;
private Search search;
@@ -83,7 +91,8 @@ class RankProfileSearchFixture {
public RankProfile compileRankProfile(String rankProfile, Path applicationDir) {
RankProfile compiled = rankProfileRegistry.get(search, rankProfile)
- .compile(queryProfileRegistry, new ImportedModels(applicationDir.toFile()));
+ .compile(queryProfileRegistry,
+ new ImportedModels(applicationDir.toFile(), importers));
compiledRankProfiles.put(rankProfile, compiled);
return compiled;
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java
index fd048737b43..78dcae53f7e 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java
@@ -8,7 +8,7 @@ import com.yahoo.searchdefinition.derived.DerivedConfiguration;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RawRankProfile;
import com.yahoo.searchdefinition.parser.ParseException;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import org.junit.Test;
import java.io.IOException;
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
index 6e3a227e2a9..7eee18c1059 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
@@ -17,7 +17,7 @@ import com.yahoo.searchdefinition.SearchDefinitionTestCase;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RawRankProfile;
import com.yahoo.searchdefinition.parser.ParseException;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels;
import org.junit.Test;
import java.util.List;
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java
index 2ae629562d0..4d8bdb7a6c6 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java
@@ -1,11 +1,17 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model.ml;
+import com.google.common.collect.ImmutableList;
import com.yahoo.config.model.ApplicationPackageTester;
+import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.searchdefinition.RankingConstant;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ModelImporter;
+import ai.vespa.rankingexpression.importer.onnx.OnnxImporter;
+import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter;
+import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.serialization.TypedBinaryFormat;
import com.yahoo.vespa.model.VespaModel;
@@ -26,6 +32,10 @@ import static org.junit.Assert.assertEquals;
*/
public class ImportedModelTester {
+ private final ImmutableList<ModelImporter> importers = ImmutableList.of(new TensorFlowImporter(),
+ new OnnxImporter(),
+ new XGBoostImporter());
+
private final String modelName;
private final Path applicationDir;
@@ -36,7 +46,10 @@ public class ImportedModelTester {
public VespaModel createVespaModel() {
try {
- return new VespaModel(ApplicationPackageTester.create(applicationDir.toString()).app());
+ DeployState.Builder state = new DeployState.Builder();
+ state.applicationPackage(ApplicationPackageTester.create(applicationDir.toString()).app());
+ state.modelImporters(importers);
+ return new VespaModel(state.build());
}
catch (SAXException | IOException e) {
throw new RuntimeException(e);
diff --git a/configserver/src/main/resources/configserver-app/services.xml b/configserver/src/main/resources/configserver-app/services.xml
index 5f60be8c202..3dd6e0090c5 100644
--- a/configserver/src/main/resources/configserver-app/services.xml
+++ b/configserver/src/main/resources/configserver-app/services.xml
@@ -55,6 +55,7 @@
<preprocess:include file='config-models.xml' required='false' />
<preprocess:include file='node-repository.xml' required='false' />
<preprocess:include file='hosted-vespa/routing-status.xml' required='false' />
+ <preprocess:include file='model-integration.xml' required='true' />
<!-- TODO Vespa 7: Remove scoreboard.xml, replaced by metrics-packets.xml -->
<preprocess:include file='hosted-vespa/scoreboard.xml' required='false' />
diff --git a/container-dev/pom.xml b/container-dev/pom.xml
index 1b0f36e3adb..7d61882c085 100644
--- a/container-dev/pom.xml
+++ b/container-dev/pom.xml
@@ -104,18 +104,6 @@
<groupId>org.apache.httpcomponents</groupId>
<artifactId>httpclient</artifactId>
</exclusion>
- <exclusion>
- <groupId>org.tensorflow</groupId>
- <artifactId>proto</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.tensorflow</groupId>
- <artifactId>tensorflow</artifactId>
- </exclusion>
- <exclusion>
- <groupId>com.google.protobuf</groupId>
- <artifactId>protobuf-java</artifactId>
- </exclusion>
</exclusions>
</dependency>
<dependency>
@@ -178,18 +166,6 @@
<groupId>xerces</groupId>
<artifactId>xercesImpl</artifactId>
</exclusion>
- <exclusion>
- <groupId>org.tensorflow</groupId>
- <artifactId>proto</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.tensorflow</groupId>
- <artifactId>tensorflow</artifactId>
- </exclusion>
- <exclusion>
- <groupId>com.google.protobuf</groupId>
- <artifactId>protobuf-java</artifactId>
- </exclusion>
</exclusions>
</dependency>
<dependency>
diff --git a/container-disc/pom.xml b/container-disc/pom.xml
index 62985192ad3..be4ac23a938 100644
--- a/container-disc/pom.xml
+++ b/container-disc/pom.xml
@@ -174,6 +174,7 @@
jdisc-security-filters-jar-with-dependencies.jar,
jdisc_http_service-jar-with-dependencies.jar,
model-evaluation-jar-with-dependencies.jar,
+ model-integration-jar-with-dependencies.jar,
vespaclient-container-plugin-jar-with-dependencies.jar,
vespa-athenz-jar-with-dependencies.jar,
security-utils-jar-with-dependencies.jar,
diff --git a/container/pom.xml b/container/pom.xml
index 32a7947d6d5..529f01e0a40 100644
--- a/container/pom.xml
+++ b/container/pom.xml
@@ -53,18 +53,6 @@
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
</exclusion>
- <exclusion>
- <groupId>org.tensorflow</groupId>
- <artifactId>proto</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.tensorflow</groupId>
- <artifactId>tensorflow</artifactId>
- </exclusion>
- <exclusion>
- <groupId>com.google.protobuf</groupId>
- <artifactId>protobuf-java</artifactId>
- </exclusion>
</exclusions>
</dependency>
</dependencies>
diff --git a/document/src/main/java/com/yahoo/document/Document.java b/document/src/main/java/com/yahoo/document/Document.java
index 85049baf8b0..222ebe29c6d 100644
--- a/document/src/main/java/com/yahoo/document/Document.java
+++ b/document/src/main/java/com/yahoo/document/Document.java
@@ -286,6 +286,7 @@ public class Document extends StructuredFieldValue {
/**
* Get JSON representation of the document root and its children contained in a JSON object
+ *
* @return JSON representation of document
*/
public String toJson() {
diff --git a/model-integration/pom.xml b/model-integration/pom.xml
new file mode 100644
index 00000000000..28a00dcbdbc
--- /dev/null
+++ b/model-integration/pom.xml
@@ -0,0 +1,100 @@
+<?xml version="1.0"?>
+<!-- Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -->
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0
+ http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+ <parent>
+ <groupId>com.yahoo.vespa</groupId>
+ <artifactId>parent</artifactId>
+ <version>6-SNAPSHOT</version>
+ <relativePath>../parent/pom.xml</relativePath>
+ </parent>
+ <artifactId>model-integration</artifactId>
+ <version>6-SNAPSHOT</version>
+ <packaging>container-plugin</packaging>
+ <dependencies>
+ <dependency>
+ <groupId>junit</groupId>
+ <artifactId>junit</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>com.yahoo.vespa</groupId>
+ <artifactId>component</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>com.yahoo.vespa</groupId>
+ <artifactId>vespajlib</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>com.yahoo.vespa</groupId>
+ <artifactId>searchlib</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>com.google.guava</groupId>
+ <artifactId>guava</artifactId>
+ <scope>provided</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>com.google.protobuf</groupId>
+ <artifactId>protobuf-java</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.tensorflow</groupId>
+ <artifactId>proto</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.tensorflow</groupId>
+ <artifactId>tensorflow</artifactId>
+ </dependency>
+ </dependencies>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>com.yahoo.vespa</groupId>
+ <artifactId>bundle-plugin</artifactId>
+ <extensions>true</extensions>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-compiler-plugin</artifactId>
+ <configuration>
+ <compilerArgs>
+ <arg>-Xlint:rawtypes</arg>
+ <arg>-Xlint:unchecked</arg>
+ <arg>-Werror</arg>
+ </compilerArgs>
+ </configuration>
+ </plugin>
+ <plugin>
+ <groupId>com.github.os72</groupId>
+ <artifactId>protoc-jar-maven-plugin</artifactId>
+ <version>3.5.1.1</version>
+ <executions>
+ <execution>
+ <phase>generate-sources</phase>
+ <goals>
+ <goal>run</goal>
+ </goals>
+ <configuration>
+ <addSources>main</addSources>
+ <outputDirectory>${project.build.directory}/generated-sources/protobuf</outputDirectory>
+ <inputDirectories>
+ <include>src/main/protobuf</include>
+ </inputDirectories>
+ </configuration>
+ </execution>
+ </executions>
+ </plugin>
+ </plugins>
+ </build>
+</project>
diff --git a/model-integration/src/main/config/model-integration.xml b/model-integration/src/main/config/model-integration.xml
new file mode 100644
index 00000000000..da45ce23575
--- /dev/null
+++ b/model-integration/src/main/config/model-integration.xml
@@ -0,0 +1,10 @@
+<!-- Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -->
+<!-- Component which can import some ml model.
+ This is included into the config server services.xml to enable it to translate
+ model pseudofeatures in ranking expressions during config mddel building.
+ It is provided as separate bundles instead of being included in the config model
+ because some of these (TensorFlow) includes
+ JNI code, and so can only exist in one instance in the server. -->
+<component id="ai.vespa.rankingexpression.importer.onnx.OnnxImporter" bundle="model-integration" />
+<component id="ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter" bundle="model-integration" />
+<component id="ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter" bundle="model-integration" />
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
index 3fe92440cae..81d0753ea4b 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
@@ -1,6 +1,6 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx;
+package ai.vespa.rankingexpression.importer.onnx;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
@@ -28,9 +28,9 @@ import java.util.stream.Collectors;
*
* @author lesters
*/
-public class GraphImporter {
+class GraphImporter {
- public static IntermediateOperation mapOperation(Onnx.NodeProto node,
+ private static IntermediateOperation mapOperation(Onnx.NodeProto node,
List<IntermediateOperation> inputs,
IntermediateGraph graph) {
String nodeName = node.getName();
@@ -79,7 +79,7 @@ public class GraphImporter {
return op;
}
- public static IntermediateGraph importGraph(String modelName, Onnx.ModelProto model) {
+ static IntermediateGraph importGraph(String modelName, Onnx.ModelProto model) {
Onnx.GraphProto onnxGraph = model.getGraph();
IntermediateGraph intermediateGraph = new IntermediateGraph(modelName);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java
index e6bb5f40b3f..0418581d7b2 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java
@@ -1,9 +1,10 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package ai.vespa.rankingexpression.importer.onnx;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx.GraphImporter;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ModelImporter;
import onnx.Onnx;
import java.io.File;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java
index 18856d4a25f..6dd33c79852 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java
@@ -1,6 +1,6 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx;
+package ai.vespa.rankingexpression.importer.onnx;
import com.google.protobuf.ByteString;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
@@ -17,9 +17,9 @@ import java.nio.FloatBuffer;
*
* @author lesters
*/
-public class TensorConverter {
+class TensorConverter {
- public static Tensor toVespaTensor(Onnx.TensorProto tensorProto, OrderedTensorType type) {
+ static Tensor toVespaTensor(Onnx.TensorProto tensorProto, OrderedTensorType type) {
Values values = readValuesOf(tensorProto);
IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type.type());
for (int i = 0; i < values.size(); i++) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
index 715c55d8323..43ceaa747b7 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
@@ -1,6 +1,6 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx;
+package ai.vespa.rankingexpression.importer.onnx;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
@@ -11,9 +11,9 @@ import onnx.Onnx;
*
* @author lesters
*/
-public class TypeConverter {
+class TypeConverter {
- public static void verifyType(Onnx.TypeProto typeProto, OrderedTensorType type) {
+ static void verifyType(Onnx.TypeProto typeProto, OrderedTensorType type) {
Onnx.TensorShapeProto shape = typeProto.getTensorType().getShape();
if (shape != null) {
if (shape.getDimCount() != type.rank()) {
@@ -30,11 +30,11 @@ public class TypeConverter {
}
}
- public static OrderedTensorType fromOnnxType(Onnx.TypeProto type) {
+ static OrderedTensorType fromOnnxType(Onnx.TypeProto type) {
return fromOnnxType(type, "d"); // standard naming convention: d0, d1, ...
}
- public static OrderedTensorType fromOnnxType(Onnx.TypeProto type, String dimensionPrefix) {
+ private static OrderedTensorType fromOnnxType(Onnx.TypeProto type, String dimensionPrefix) {
Onnx.TensorShapeProto shape = type.getTensorType().getShape();
OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
for (int i = 0; i < shape.getDimCount(); ++ i) {
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/package-info.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/package-info.java
new file mode 100644
index 00000000000..9599cf8627c
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/package-info.java
@@ -0,0 +1,5 @@
+@ExportPackage
+
+package ai.vespa.rankingexpression.importer.onnx;
+
+import com.yahoo.osgi.annotation.ExportPackage; \ No newline at end of file
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/AttributeConverter.java
index 89b75e8e3e2..73310c78cab 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/AttributeConverter.java
@@ -1,5 +1,5 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow;
+package ai.vespa.rankingexpression.importer.tensorflow;
import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
@@ -20,15 +20,15 @@ import java.util.stream.Collectors;
*
* @author lesters
*/
-public class AttributeConverter implements IntermediateOperation.AttributeMap {
+class AttributeConverter implements IntermediateOperation.AttributeMap {
private final Map<String, AttrValue> attributeMap;
- public AttributeConverter(NodeDef node) {
+ private AttributeConverter(NodeDef node) {
attributeMap = node.getAttrMap();
}
- public static AttributeConverter convert(NodeDef node) {
+ static AttributeConverter convert(NodeDef node) {
return new AttributeConverter(node);
}
@@ -83,4 +83,5 @@ public class AttributeConverter implements IntermediateOperation.AttributeMap {
}
return Optional.empty();
}
+
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java
index e1b292f9e61..c012bc3c54f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java
@@ -1,6 +1,6 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow;
+package ai.vespa.rankingexpression.importer.tensorflow;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
@@ -43,9 +43,9 @@ import java.util.stream.Collectors;
*
* @author lesters
*/
-public class GraphImporter {
+class GraphImporter {
- public static IntermediateOperation mapOperation(NodeDef node,
+ private static IntermediateOperation mapOperation(NodeDef node,
List<IntermediateOperation> inputs,
IntermediateGraph graph) {
String nodeName = node.getName();
@@ -112,7 +112,7 @@ public class GraphImporter {
return op;
}
- public static IntermediateGraph importGraph(String modelName, SavedModelBundle bundle) throws IOException {
+ static IntermediateGraph importGraph(String modelName, SavedModelBundle bundle) throws IOException {
MetaGraphDef tfGraph = MetaGraphDef.parseFrom(bundle.metaGraphDef());
IntermediateGraph intermediateGraph = new IntermediateGraph(modelName);
@@ -209,7 +209,7 @@ public class GraphImporter {
throw new IllegalArgumentException("Could not find node '" + name + "'");
}
- public static org.tensorflow.Tensor<?> readVariable(String name, SavedModelBundle bundle) {
+ static org.tensorflow.Tensor<?> readVariable(String name, SavedModelBundle bundle) {
Session.Runner fetched = bundle.session().runner().fetch(name);
List<org.tensorflow.Tensor<?>> importedTensors = fetched.run();
if (importedTensors.size() != 1)
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java
index d2d0acfc964..4e67286ef09 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TensorConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java
@@ -1,5 +1,5 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow;
+package ai.vespa.rankingexpression.importer.tensorflow;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.IndexedTensor;
@@ -26,7 +26,7 @@ public class TensorConverter {
return toVespaTensor(tfTensor, "d");
}
- public static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) {
+ private static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) {
TensorType type = toVespaTensorType(tfTensor.shape(), dimensionPrefix);
Values values = readValuesOf(tfTensor);
IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type);
@@ -35,7 +35,7 @@ public class TensorConverter {
return builder.build();
}
- public static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, OrderedTensorType type) {
+ static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, OrderedTensorType type) {
Values values = readValuesOf(tfTensor);
IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type.type());
for (int i = 0; i < values.size(); i++) {
@@ -44,7 +44,7 @@ public class TensorConverter {
return builder.build();
}
- public static Tensor toVespaTensor(TensorProto tensorProto, TensorType type) {
+ static Tensor toVespaTensor(TensorProto tensorProto, TensorType type) {
IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type);
Values values = readValuesOf(tensorProto);
for (int i = 0; i < values.size(); ++i) {
@@ -71,7 +71,7 @@ public class TensorConverter {
return size;
}
- public static Long dimensionSize(TensorType.Dimension dim) {
+ private static Long dimensionSize(TensorType.Dimension dim) {
return dim.size().orElseThrow(() -> new IllegalArgumentException("Dimension has no size"));
}
@@ -182,8 +182,8 @@ public class TensorConverter {
}
private static abstract class ProtoValues extends Values {
- protected final TensorProto tensorProto;
- protected ProtoValues(TensorProto tensorProto) { this.tensorProto = tensorProto; }
+ final TensorProto tensorProto;
+ ProtoValues(TensorProto tensorProto) { this.tensorProto = tensorProto; }
}
private static class ProtoBoolValues extends ProtoValues {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java
index 7c18e04bae7..f8a25e4d94c 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java
@@ -1,8 +1,9 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package ai.vespa.rankingexpression.importer.tensorflow;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.GraphImporter;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ModelImporter;
import org.tensorflow.SavedModelBundle;
import java.io.File;
@@ -47,7 +48,7 @@ public class TensorFlowImporter extends ModelImporter {
}
/** Imports a TensorFlow model */
- ImportedModel importModel(String modelName, String modelDir, SavedModelBundle model) {
+ public ImportedModel importModel(String modelName, String modelDir, SavedModelBundle model) {
try {
IntermediateGraph graph = GraphImporter.importGraph(modelName, model);
return convertIntermediateGraphToModel(graph, modelDir);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java
index 67ad1edc312..a5a506fdb6d 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java
@@ -1,6 +1,6 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow;
+package ai.vespa.rankingexpression.importer.tensorflow;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
@@ -15,9 +15,9 @@ import java.util.List;
*
* @author lesters
*/
-public class TypeConverter {
+class TypeConverter {
- public static void verifyType(NodeDef node, OrderedTensorType type) {
+ static void verifyType(NodeDef node, OrderedTensorType type) {
TensorShapeProto shape = tensorFlowShape(node);
if (shape != null) {
if (shape.getDimCount() != type.rank()) {
@@ -50,11 +50,11 @@ public class TypeConverter {
return shapeList.get(0); // support multiple outputs?
}
- public static OrderedTensorType fromTensorFlowType(NodeDef node) {
+ static OrderedTensorType fromTensorFlowType(NodeDef node) {
return fromTensorFlowType(node, "d"); // standard naming convention: d0, d1, ...
}
- public static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) {
+ private static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) {
OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
TensorShapeProto shape = tensorFlowShape(node);
for (int i = 0; i < shape.getDimCount(); ++ i) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/VariableConverter.java
index 25bac27f315..b777ee07e58 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/VariableConverter.java
@@ -1,9 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package ai.vespa.rankingexpression.importer.tensorflow;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.TensorConverter;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.GraphImporter;
import com.yahoo.tensor.serialization.JsonFormat;
import com.yahoo.yolean.Exceptions;
import org.tensorflow.SavedModelBundle;
@@ -16,7 +14,7 @@ import java.nio.charset.StandardCharsets;
*
* @author bratseth
*/
-public class VariableConverter {
+class VariableConverter {
/**
* Reads the tensor with the given TensorFlow name at the given model location,
@@ -24,7 +22,7 @@ public class VariableConverter {
* Note that order of dimensions in the tensor type does matter as the TensorFlow tensor
* tensor dimensions are implicitly ordered.
*/
- public static byte[] importVariable(String modelDir, String tensorFlowVariableName, String orderedTypeSpec) {
+ static byte[] importVariable(String modelDir, String tensorFlowVariableName, String orderedTypeSpec) {
try (SavedModelBundle bundle = SavedModelBundle.load(modelDir, "serve")) {
return JsonFormat.encode(TensorConverter.toVespaTensor(GraphImporter.readVariable(tensorFlowVariableName,
bundle),
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/package-info.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/package-info.java
new file mode 100644
index 00000000000..0840e584d25
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/package-info.java
@@ -0,0 +1,4 @@
+@ExportPackage
+package ai.vespa.rankingexpression.importer.tensorflow;
+
+import com.yahoo.osgi.annotation.ExportPackage; \ No newline at end of file
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/XGBoostImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java
index 725f319a839..e87dd265d50 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/XGBoostImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java
@@ -1,8 +1,9 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package ai.vespa.rankingexpression.importer.xgboost;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.xgboost.XGBoostParser;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ModelImporter;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import java.io.File;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostParser.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostParser.java
index fef8bfec81d..2b215f816f5 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostParser.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostParser.java
@@ -1,5 +1,5 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.xgboost;
+package ai.vespa.rankingexpression.importer.xgboost;
import java.io.File;
import java.io.IOException;
@@ -13,7 +13,7 @@ import com.fasterxml.jackson.databind.ObjectMapper;
/**
* @author grace-lam
*/
-public class XGBoostParser {
+class XGBoostParser {
private List<XGBoostTree> xgboostTrees;
@@ -24,7 +24,7 @@ public class XGBoostParser {
* @throws JsonProcessingException Fails JSON parsing.
* @throws IOException Fails file reading.
*/
- public XGBoostParser(String filePath) throws JsonProcessingException, IOException {
+ XGBoostParser(String filePath) throws JsonProcessingException, IOException {
this.xgboostTrees = new ArrayList<>();
ObjectMapper mapper = new ObjectMapper();
JsonNode forestNode = mapper.readTree(new File(filePath));
@@ -38,7 +38,7 @@ public class XGBoostParser {
*
* @return Vespa ranking expressions.
*/
- public String toRankingExpression() {
+ String toRankingExpression() {
StringBuilder ret = new StringBuilder();
for (int i = 0; i < xgboostTrees.size(); i++) {
ret.append(treeToRankExp(xgboostTrees.get(i)));
@@ -55,7 +55,7 @@ public class XGBoostParser {
* @param node XGBoost tree node to convert.
* @return Vespa ranking expression for input node.
*/
- public String treeToRankExp(XGBoostTree node) {
+ private String treeToRankExp(XGBoostTree node) {
if (node.isLeaf()) {
return Double.toString(node.getLeaf());
} else {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostTree.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostTree.java
index 6bbc9abe8ae..e32e0f1eab5 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostTree.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostTree.java
@@ -1,5 +1,5 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.xgboost;
+package ai.vespa.rankingexpression.importer.xgboost;
import java.util.List;
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/package-info.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/package-info.java
new file mode 100644
index 00000000000..d310de9041b
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/package-info.java
@@ -0,0 +1,4 @@
+@ExportPackage
+package ai.vespa.rankingexpression.importer.xgboost;
+
+import com.yahoo.osgi.annotation.ExportPackage; \ No newline at end of file
diff --git a/searchlib/src/main/protobuf/onnx.proto b/model-integration/src/main/protobuf/onnx.proto
index dc6542867e0..dc6542867e0 100644
--- a/searchlib/src/main/protobuf/onnx.proto
+++ b/model-integration/src/main/protobuf/onnx.proto
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
index 6d0ee0a906d..a71d7a42551 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
@@ -1,11 +1,13 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package ai.vespa.rankingexpression.importer.onnx;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel;
+import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.junit.Test;
@@ -21,7 +23,7 @@ public class OnnxMnistSoftmaxImportTestCase {
@Test
public void testMnistSoftmaxImport() {
- ImportedModel model = new OnnxImporter().importModel("test", "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx");
+ ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/mnist_softmax/mnist_softmax.onnx");
// Check constants
assertEquals(2, model.largeConstants().size());
@@ -55,8 +57,8 @@ public class OnnxMnistSoftmaxImportTestCase {
@Test
public void testComparisonBetweenOnnxAndTensorflow() {
- String tfModelPath = "src/test/files/integration/tensorflow/mnist_softmax/saved";
- String onnxModelPath = "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx";
+ String tfModelPath = "src/test/models/tensorflow/mnist_softmax/saved";
+ String onnxModelPath = "src/test/models/onnx/mnist_softmax/mnist_softmax.onnx";
Tensor argument = placeholderArgument();
Tensor tensorFlowResult = evaluateTensorFlowModel(tfModelPath, argument, "Placeholder", "add");
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BatchNormImportTestCase.java
index e325c3d11b4..acd649d985b 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BatchNormImportTestCase.java
@@ -1,8 +1,8 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package ai.vespa.rankingexpression.importer.tensorflow;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
@@ -16,7 +16,7 @@ public class BatchNormImportTestCase {
@Test
public void testBatchNormImport() {
TestableTensorFlowModel model = new TestableTensorFlowModel("test",
- "src/test/files/integration/tensorflow/batch_norm/saved");
+ "src/test/models/tensorflow/batch_norm/saved");
ImportedModel.Signature signature = model.get().signature("serving_default");
assertEquals("Has skipped outputs",
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BlogEvaluationBenchmark.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BlogEvaluationBenchmark.java
index 5b0cca6a940..a878b284b2c 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BlogEvaluationBenchmark.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BlogEvaluationBenchmark.java
@@ -1,5 +1,5 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package ai.vespa.rankingexpression.importer.tensorflow;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
@@ -7,7 +7,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
import com.yahoo.searchlib.rankingexpression.evaluation.OptimizationReport;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -17,8 +17,6 @@ import org.tensorflow.Session;
import java.nio.FloatBuffer;
import java.util.List;
-import static com.yahoo.searchlib.rankingexpression.integration.ml.TestableTensorFlowModel.contextFrom;
-
/**
* Microbenchmark of imported ML model evaluation.
*
@@ -26,13 +24,13 @@ import static com.yahoo.searchlib.rankingexpression.integration.ml.TestableTenso
*/
public class BlogEvaluationBenchmark {
- static final String modelDir = "src/test/files/integration/tensorflow/blog/saved";
+ static final String modelDir = "src/test/models/tensorflow/blog/saved";
public static void main(String[] args) throws ParseException {
SavedModelBundle tensorFlowModel = SavedModelBundle.load(modelDir, "serve");
ImportedModel model = new TensorFlowImporter().importModel("blog", modelDir, tensorFlowModel);
- Context context = contextFrom(model);
+ Context context = TestableTensorFlowModel.contextFrom(model);
Tensor u = generateInputTensor();
Tensor d = generateInputTensor();
context.put("input_u", new TensorValue(u));
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java
index 8ca5a9a7888..6e58761e5ce 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java
@@ -1,9 +1,10 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package ai.vespa.rankingexpression.importer.tensorflow;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel;
import com.yahoo.tensor.TensorType;
+import org.junit.Assert;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
@@ -17,18 +18,18 @@ public class DropoutImportTestCase {
@Test
public void testDropoutImport() {
- TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/dropout/saved");
+ TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/models/tensorflow/dropout/saved");
// Check required functions
- assertEquals(1, model.get().inputs().size());
+ Assert.assertEquals(1, model.get().inputs().size());
assertTrue(model.get().inputs().containsKey("X"));
- assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(),
- model.get().inputs().get("X"));
+ Assert.assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(),
+ model.get().inputs().get("X"));
ImportedModel.Signature signature = model.get().signature("serving_default");
- assertEquals("Has skipped outputs",
- 0, model.get().signature("serving_default").skippedOutputs().size());
+ Assert.assertEquals("Has skipped outputs",
+ 0, model.get().signature("serving_default").skippedOutputs().size());
ExpressionFunction function = signature.outputExpression("y");
assertNotNull(function);
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/MnistImportTestCase.java
index 3d8d5d5a570..b338f46fb4d 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/MnistImportTestCase.java
@@ -1,8 +1,9 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package ai.vespa.rankingexpression.importer.tensorflow;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel;
+import org.junit.Assert;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
@@ -15,11 +16,11 @@ public class MnistImportTestCase {
@Test
public void testMnistImport() {
- TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/mnist/saved");
+ TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/models/tensorflow/mnist/saved");
ImportedModel.Signature signature = model.get().signature("serving_default");
- assertEquals("Has skipped outputs",
- 0, model.get().signature("serving_default").skippedOutputs().size());
+ Assert.assertEquals("Has skipped outputs",
+ 0, model.get().signature("serving_default").skippedOutputs().size());
ExpressionFunction output = signature.outputExpression("y");
assertNotNull(output);
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java
index feba40601e3..7e8cbef8ada 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java
@@ -1,10 +1,11 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package ai.vespa.rankingexpression.importer.tensorflow;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import org.junit.Assert;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
@@ -18,10 +19,10 @@ public class TensorFlowMnistSoftmaxImportTestCase {
@Test
public void testMnistSoftmaxImport() {
- TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/mnist_softmax/saved");
+ TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/models/tensorflow/mnist_softmax/saved");
// Check constants
- assertEquals(2, model.get().largeConstants().size());
+ Assert.assertEquals(2, model.get().largeConstants().size());
Tensor constant0 = model.get().largeConstants().get("test_Variable_read");
assertNotNull(constant0);
@@ -36,16 +37,16 @@ public class TensorFlowMnistSoftmaxImportTestCase {
assertEquals(10, constant1.size());
// Check (provided) functions
- assertEquals(0, model.get().functions().size());
+ Assert.assertEquals(0, model.get().functions().size());
// Check required functions
- assertEquals(1, model.get().inputs().size());
+ Assert.assertEquals(1, model.get().inputs().size());
assertTrue(model.get().inputs().containsKey("Placeholder"));
- assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(),
- model.get().inputs().get("Placeholder"));
+ Assert.assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(),
+ model.get().inputs().get("Placeholder"));
// Check signatures
- assertEquals(1, model.get().signatures().size());
+ Assert.assertEquals(1, model.get().signatures().size());
ImportedModel.Signature signature = model.get().signatures().get("serving_default");
assertNotNull(signature);
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java
index fbe7c5fac63..dbed537885e 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java
@@ -1,5 +1,5 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package ai.vespa.rankingexpression.importer.tensorflow;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
@@ -7,7 +7,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverterTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/VariableConverterTestCase.java
index aabbdf33c9d..c9fffe143b4 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverterTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/VariableConverterTestCase.java
@@ -1,5 +1,5 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package ai.vespa.rankingexpression.importer.tensorflow;
import org.junit.Test;
@@ -11,7 +11,7 @@ public class VariableConverterTestCase {
@Test
public void testConversion() {
- byte[] converted = VariableConverter.importVariable("src/test/files/integration/tensorflow/mnist_softmax/saved",
+ byte[] converted = VariableConverter.importVariable("src/test/models/tensorflow/mnist_softmax/saved",
"Variable_1",
"tensor(d0[10],d1[1])");
assertEquals("{\"cells\":[{\"address\":{\"d0\":\"0\",\"d1\":\"0\"},\"value\":-0.3546536862850189},{\"address\":{\"d0\":\"1\",\"d1\":\"0\"},\"value\":0.3759574592113495},{\"address\":{\"d0\":\"2\",\"d1\":\"0\"},\"value\":0.06054411828517914},{\"address\":{\"d0\":\"3\",\"d1\":\"0\"},\"value\":-0.251544713973999},{\"address\":{\"d0\":\"4\",\"d1\":\"0\"},\"value\":0.017951013520359993},{\"address\":{\"d0\":\"5\",\"d1\":\"0\"},\"value\":1.2899067401885986},{\"address\":{\"d0\":\"6\",\"d1\":\"0\"},\"value\":-0.10389615595340729},{\"address\":{\"d0\":\"7\",\"d1\":\"0\"},\"value\":0.6367976665496826},{\"address\":{\"d0\":\"8\",\"d1\":\"0\"},\"value\":-1.4136744737625122},{\"address\":{\"d0\":\"9\",\"d1\":\"0\"},\"value\":-0.2573896050453186}]}",
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java
new file mode 100644
index 00000000000..48c7f5bee19
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java
@@ -0,0 +1,28 @@
+package ai.vespa.rankingexpression.importer.xgboost;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * @author bratseth
+ */
+public class XGBoostImportTestCase {
+
+ @Test
+ public void testXGBoost() {
+ ImportedModel model = new XGBoostImporter().importModel("test", "src/test/models/xgboost/xgboost.2.2.json");
+ assertTrue("All inputs are scalar", model.inputs().isEmpty());
+ assertEquals(1, model.expressions().size());
+ System.out.println(model.expressions().keySet());
+ RankingExpression expression = model.expressions().get("test");
+ assertNotNull(expression);
+ assertEquals("if (f29 < -0.1234567, if (f56 < -0.242398, 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + if (f60 < -0.482947, if (f29 < -4.2387498, 0.784718, -0.96853), -6.23624)",
+ expression.getRoot().toString());
+ }
+
+}
diff --git a/searchlib/src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx b/model-integration/src/test/models/onnx/mnist_softmax/mnist_softmax.onnx
index a86019bf53a..a86019bf53a 100644
--- a/searchlib/src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx
+++ b/model-integration/src/test/models/onnx/mnist_softmax/mnist_softmax.onnx
Binary files differ
diff --git a/searchlib/src/test/files/integration/tensorflow/batch_norm/batch_normalization_mnist.py b/model-integration/src/test/models/tensorflow/batch_norm/batch_normalization_mnist.py
index bc6ea13ebc1..bc6ea13ebc1 100644
--- a/searchlib/src/test/files/integration/tensorflow/batch_norm/batch_normalization_mnist.py
+++ b/model-integration/src/test/models/tensorflow/batch_norm/batch_normalization_mnist.py
diff --git a/searchlib/src/test/files/integration/tensorflow/batch_norm/saved/saved_model.pbtxt b/model-integration/src/test/models/tensorflow/batch_norm/saved/saved_model.pbtxt
index f3ce68a1cbd..f3ce68a1cbd 100644
--- a/searchlib/src/test/files/integration/tensorflow/batch_norm/saved/saved_model.pbtxt
+++ b/model-integration/src/test/models/tensorflow/batch_norm/saved/saved_model.pbtxt
diff --git a/searchlib/src/test/files/integration/tensorflow/batch_norm/saved/variables/variables.data-00000-of-00001 b/model-integration/src/test/models/tensorflow/batch_norm/saved/variables/variables.data-00000-of-00001
index 875e8361e10..875e8361e10 100644
--- a/searchlib/src/test/files/integration/tensorflow/batch_norm/saved/variables/variables.data-00000-of-00001
+++ b/model-integration/src/test/models/tensorflow/batch_norm/saved/variables/variables.data-00000-of-00001
Binary files differ
diff --git a/searchlib/src/test/files/integration/tensorflow/batch_norm/saved/variables/variables.index b/model-integration/src/test/models/tensorflow/batch_norm/saved/variables/variables.index
index 46c7b258cf5..46c7b258cf5 100644
--- a/searchlib/src/test/files/integration/tensorflow/batch_norm/saved/variables/variables.index
+++ b/model-integration/src/test/models/tensorflow/batch_norm/saved/variables/variables.index
Binary files differ
diff --git a/searchlib/src/test/files/integration/tensorflow/blog/saved/saved_model.pbtxt b/model-integration/src/test/models/tensorflow/blog/saved/saved_model.pbtxt
index a669e69b709..a669e69b709 100644
--- a/searchlib/src/test/files/integration/tensorflow/blog/saved/saved_model.pbtxt
+++ b/model-integration/src/test/models/tensorflow/blog/saved/saved_model.pbtxt
diff --git a/searchlib/src/test/files/integration/tensorflow/blog/saved/variables/variables.data-00000-of-00001 b/model-integration/src/test/models/tensorflow/blog/saved/variables/variables.data-00000-of-00001
index 1efd102aef9..1efd102aef9 100644
--- a/searchlib/src/test/files/integration/tensorflow/blog/saved/variables/variables.data-00000-of-00001
+++ b/model-integration/src/test/models/tensorflow/blog/saved/variables/variables.data-00000-of-00001
Binary files differ
diff --git a/searchlib/src/test/files/integration/tensorflow/blog/saved/variables/variables.index b/model-integration/src/test/models/tensorflow/blog/saved/variables/variables.index
index 56c60dbe529..56c60dbe529 100644
--- a/searchlib/src/test/files/integration/tensorflow/blog/saved/variables/variables.index
+++ b/model-integration/src/test/models/tensorflow/blog/saved/variables/variables.index
Binary files differ
diff --git a/searchlib/src/test/files/integration/tensorflow/dropout/dropout.py b/model-integration/src/test/models/tensorflow/dropout/dropout.py
index 42c15cd2812..42c15cd2812 100644
--- a/searchlib/src/test/files/integration/tensorflow/dropout/dropout.py
+++ b/model-integration/src/test/models/tensorflow/dropout/dropout.py
diff --git a/searchlib/src/test/files/integration/tensorflow/dropout/saved/saved_model.pbtxt b/model-integration/src/test/models/tensorflow/dropout/saved/saved_model.pbtxt
index ad431f0460d..ad431f0460d 100644
--- a/searchlib/src/test/files/integration/tensorflow/dropout/saved/saved_model.pbtxt
+++ b/model-integration/src/test/models/tensorflow/dropout/saved/saved_model.pbtxt
diff --git a/searchlib/src/test/files/integration/tensorflow/dropout/saved/variables/variables.data-00000-of-00001 b/model-integration/src/test/models/tensorflow/dropout/saved/variables/variables.data-00000-of-00001
index 000c9b3a7b5..000c9b3a7b5 100644
--- a/searchlib/src/test/files/integration/tensorflow/dropout/saved/variables/variables.data-00000-of-00001
+++ b/model-integration/src/test/models/tensorflow/dropout/saved/variables/variables.data-00000-of-00001
Binary files differ
diff --git a/searchlib/src/test/files/integration/tensorflow/dropout/saved/variables/variables.index b/model-integration/src/test/models/tensorflow/dropout/saved/variables/variables.index
index 9492ef4bde2..9492ef4bde2 100644
--- a/searchlib/src/test/files/integration/tensorflow/dropout/saved/variables/variables.index
+++ b/model-integration/src/test/models/tensorflow/dropout/saved/variables/variables.index
Binary files differ
diff --git a/searchlib/src/test/files/integration/tensorflow/mnist/saved/saved_model.pbtxt b/model-integration/src/test/models/tensorflow/mnist/saved/saved_model.pbtxt
index eb926836576..eb926836576 100644
--- a/searchlib/src/test/files/integration/tensorflow/mnist/saved/saved_model.pbtxt
+++ b/model-integration/src/test/models/tensorflow/mnist/saved/saved_model.pbtxt
diff --git a/searchlib/src/test/files/integration/tensorflow/mnist/saved/variables/variables.data-00000-of-00001 b/model-integration/src/test/models/tensorflow/mnist/saved/variables/variables.data-00000-of-00001
index a7ca01888c7..a7ca01888c7 100644
--- a/searchlib/src/test/files/integration/tensorflow/mnist/saved/variables/variables.data-00000-of-00001
+++ b/model-integration/src/test/models/tensorflow/mnist/saved/variables/variables.data-00000-of-00001
Binary files differ
diff --git a/searchlib/src/test/files/integration/tensorflow/mnist/saved/variables/variables.index b/model-integration/src/test/models/tensorflow/mnist/saved/variables/variables.index
index 7989c109a3a..7989c109a3a 100644
--- a/searchlib/src/test/files/integration/tensorflow/mnist/saved/variables/variables.index
+++ b/model-integration/src/test/models/tensorflow/mnist/saved/variables/variables.index
Binary files differ
diff --git a/searchlib/src/test/files/integration/tensorflow/mnist/simple_mnist.py b/model-integration/src/test/models/tensorflow/mnist/simple_mnist.py
index 86a17e81f8f..86a17e81f8f 100644
--- a/searchlib/src/test/files/integration/tensorflow/mnist/simple_mnist.py
+++ b/model-integration/src/test/models/tensorflow/mnist/simple_mnist.py
diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py b/model-integration/src/test/models/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py
index 07a9fa4a213..07a9fa4a213 100644
--- a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py
+++ b/model-integration/src/test/models/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py
diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt b/model-integration/src/test/models/tensorflow/mnist_softmax/saved/saved_model.pbtxt
index 8100dfd594d..8100dfd594d 100644
--- a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt
+++ b/model-integration/src/test/models/tensorflow/mnist_softmax/saved/saved_model.pbtxt
diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001 b/model-integration/src/test/models/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001
index 8474aa0a04c..8474aa0a04c 100644
--- a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001
+++ b/model-integration/src/test/models/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001
Binary files differ
diff --git a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.index b/model-integration/src/test/models/tensorflow/mnist_softmax/saved/variables/variables.index
index cfcdac20409..cfcdac20409 100644
--- a/searchlib/src/test/files/integration/tensorflow/mnist_softmax/saved/variables/variables.index
+++ b/model-integration/src/test/models/tensorflow/mnist_softmax/saved/variables/variables.index
Binary files differ
diff --git a/model-integration/src/test/models/xgboost/xgboost.2.2.json b/model-integration/src/test/models/xgboost/xgboost.2.2.json
new file mode 100644
index 00000000000..f8949b47e52
--- /dev/null
+++ b/model-integration/src/test/models/xgboost/xgboost.2.2.json
@@ -0,0 +1,19 @@
+[
+ { "nodeid": 0, "depth": 0, "split": "f29", "split_condition": -0.1234567, "yes": 1, "no": 2, "missing": 1, "children": [
+ { "nodeid": 1, "depth": 1, "split": "f56", "split_condition": -0.242398, "yes": 3, "no": 4, "missing": 3, "children": [
+ { "nodeid": 3, "leaf": 1.71218 },
+ { "nodeid": 4, "leaf": -1.70044 }
+ ]},
+ { "nodeid": 2, "depth": 1, "split": "f109", "split_condition": 0.8723473, "yes": 5, "no": 6, "missing": 5, "children": [
+ { "nodeid": 5, "leaf": -1.94071 },
+ { "nodeid": 6, "leaf": 1.85965 }
+ ]}
+ ]},
+ { "nodeid": 0, "depth": 0, "split": "f60", "split_condition": -0.482947, "yes": 1, "no": 2, "missing": 1, "children": [
+ { "nodeid": 1, "depth": 1, "split": "f29", "split_condition": -4.2387498, "yes": 3, "no": 4, "missing": 3, "children": [
+ { "nodeid": 3, "leaf": 0.784718 },
+ { "nodeid": 4, "leaf": -0.96853 }
+ ]},
+ { "nodeid": 2, "leaf": -6.23624 }
+ ]}
+] \ No newline at end of file
diff --git a/pom.xml b/pom.xml
index f749adbdabf..386aa5bc5a2 100644
--- a/pom.xml
+++ b/pom.xml
@@ -94,6 +94,7 @@
<module>messagebus</module>
<module>metrics</module>
<module>model-evaluation</module>
+ <module>model-integration</module>
<module>node-repository</module>
<module>node-admin</module>
<module>node-maintainer</module>
diff --git a/searchlib/pom.xml b/searchlib/pom.xml
index f1be4e96269..e0ce822e593 100644
--- a/searchlib/pom.xml
+++ b/searchlib/pom.xml
@@ -35,18 +35,6 @@
<version>${project.version}</version>
</dependency>
<dependency>
- <groupId>com.google.protobuf</groupId>
- <artifactId>protobuf-java</artifactId>
- </dependency>
- <dependency>
- <groupId>org.tensorflow</groupId>
- <artifactId>proto</artifactId>
- </dependency>
- <dependency>
- <groupId>org.tensorflow</groupId>
- <artifactId>tensorflow</artifactId>
- </dependency>
- <dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-core</artifactId>
<scope>provided</scope>
@@ -117,26 +105,6 @@
</execution>
</executions>
</plugin>
- <plugin>
- <groupId>com.github.os72</groupId>
- <artifactId>protoc-jar-maven-plugin</artifactId>
- <version>3.5.1.1</version>
- <executions>
- <execution>
- <phase>generate-sources</phase>
- <goals>
- <goal>run</goal>
- </goals>
- <configuration>
- <addSources>main</addSources>
- <outputDirectory>${project.build.directory}/generated-sources/protobuf</outputDirectory>
- <inputDirectories>
- <include>src/main/protobuf</include>
- </inputDirectories>
- </configuration>
- </execution>
- </executions>
- </plugin>
</plugins>
</build>
</project>
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/ImportedModel.java
index 59ec66b7209..854a5202916 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/ImportedModel.java
@@ -1,7 +1,6 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer;
-import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.yahoo.collections.Pair;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
@@ -10,13 +9,11 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.util.ArrayList;
-import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
-import java.util.Objects;
import java.util.Optional;
import java.util.regex.Pattern;
@@ -64,8 +61,7 @@ public class ImportedModel {
/**
* Returns an immutable map of the small constants of this.
- * These should have sizes up to a few kb at most, and correspond to constant
- * values given in the TensorFlow or ONNX source.
+ * These should have sizes up to a few kb at most, and correspond to constant values given in the source model.
*/
public Map<String, Tensor> smallConstants() { return Collections.unmodifiableMap(smallConstants); }
@@ -93,18 +89,18 @@ public class ImportedModel {
public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); }
/** Returns the given signature. If it does not already exist it is added to this. */
- Signature signature(String name) {
+ public Signature signature(String name) {
return signatures.computeIfAbsent(name, Signature::new);
}
/** Convenience method for returning a default signature */
- Signature defaultSignature() { return signature(defaultSignatureName); }
+ public Signature defaultSignature() { return signature(defaultSignatureName); }
- void input(String name, TensorType argumentType) { inputs.put(name, argumentType); }
- void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); }
- void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); }
- void expression(String name, RankingExpression expression) { expressions.put(name, expression); }
- void function(String name, RankingExpression expression) { functions.put(name, expression); }
+ public void input(String name, TensorType argumentType) { inputs.put(name, argumentType); }
+ public void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); }
+ public void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); }
+ public void expression(String name, RankingExpression expression) { expressions.put(name, expression); }
+ public void function(String name, RankingExpression expression) { functions.put(name, expression); }
/**
* Returns all the output expressions of this indexed by name. The names consist of one or two parts
@@ -116,9 +112,9 @@ public class ImportedModel {
List<Pair<String, ExpressionFunction>> expressions = new ArrayList<>();
for (Map.Entry<String, Signature> signatureEntry : signatures().entrySet()) {
for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet())
- expressions.add(new Pair<>(signatureEntry.getKey() + "." + outputEntry.getKey(),
+ expressions.add(new Pair<>(signatureEntry.getKey() + "" + outputEntry.getKey(),
signatureEntry.getValue().outputExpression(outputEntry.getKey())
- .withName(signatureEntry.getKey() + "." + outputEntry.getKey())));
+ .withName(signatureEntry.getKey() + "" + outputEntry.getKey())));
if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs
expressions.add(new Pair<>(signatureEntry.getKey(),
new ExpressionFunction(signatureEntry.getKey(),
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/ImportedModels.java
index 40d1ca8030a..55f1eef741c 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/ImportedModels.java
@@ -1,7 +1,6 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer;
-import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.yahoo.path.Path;
@@ -24,27 +23,26 @@ public class ImportedModels {
/** All imported models, indexed by their names */
private final ImmutableMap<String, ImportedModel> importedModels;
- private static final ImmutableList<ModelImporter> importers =
- ImmutableList.of(new TensorFlowImporter(), new OnnxImporter(), new XGBoostImporter());
-
/** Create a null imported models */
public ImportedModels() {
importedModels = ImmutableMap.of();
}
- public ImportedModels(File modelsDirectory) {
+ public ImportedModels(File modelsDirectory, Collection<ModelImporter> importers) {
Map<String, ImportedModel> models = new HashMap<>();
// Find all subdirectories recursively which contains a model we can read
- importRecursively(modelsDirectory, models);
+ importRecursively(modelsDirectory, models, importers);
importedModels = ImmutableMap.copyOf(models);
}
- private static void importRecursively(File dir, Map<String, ImportedModel> models) {
+ private static void importRecursively(File dir,
+ Map<String, ImportedModel> models,
+ Collection<ModelImporter> importers) {
if ( ! dir.isDirectory()) return;
Arrays.stream(dir.listFiles()).sorted().forEach(child -> {
- Optional<ModelImporter> importer = findImporterOf(child);
+ Optional<ModelImporter> importer = findImporterOf(child, importers);
if (importer.isPresent()) {
String name = toName(child);
ImportedModel existing = models.get(name);
@@ -54,12 +52,12 @@ public class ImportedModels {
models.put(name, importer.get().importModel(name, child));
}
else {
- importRecursively(child, models);
+ importRecursively(child, models, importers);
}
});
}
- private static Optional<ModelImporter> findImporterOf(File path) {
+ private static Optional<ModelImporter> findImporterOf(File path, Collection<ModelImporter> importers) {
return importers.stream().filter(item -> item.canImport(path.toString())).findFirst();
}
@@ -93,7 +91,7 @@ public class ImportedModels {
}
private static Path stripFileEnding(Path path) {
- int dotIndex = path.last().lastIndexOf(".");
+ int dotIndex = path.last().lastIndexOf("");
if (dotIndex <= 0) return path;
return path.withLast(path.last().substring(0, dotIndex));
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/ModelImporter.java
index 481b7f9397a..1b6494e8ce8 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/ModelImporter.java
@@ -1,11 +1,9 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Constant;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
@@ -47,7 +45,7 @@ public abstract class ModelImporter {
* Takes an IntermediateGraph and converts it to a ImportedModel containing
* the actual Vespa ranking expressions.
*/
- static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph, String modelSource) {
+ public static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph, String modelSource) {
ImportedModel model = new ImportedModel(graph.name(), modelSource);
graph.optimize();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java
index 7fc2aae87d1..ab6a80a193a 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java
@@ -1,8 +1,8 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.VariableTensor;
import com.yahoo.tensor.functions.Rename;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java
index 1b8c62fe0e9..1f479cd2e0b 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java
@@ -1,8 +1,8 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java
index 0eff8e8bc08..aab50f422be 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java
@@ -180,7 +180,7 @@ public abstract class IntermediateOperation {
/**
* An interface mapping operation attributes to Vespa Values.
- * Adapter for differences in ONNX/TensorFlow.
+ * Adapter for differences in different model types.
*/
public interface AttributeMap {
Optional<Value> get(String key);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java
index 8413ed74118..2d401469e40 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java
@@ -1,8 +1,8 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.TensorFunction;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java
index 95a77c07590..6ce9abf2ec9 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java
@@ -3,8 +3,8 @@ package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java
index e91c2305f7d..ff87412396d 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java
@@ -2,8 +2,8 @@
package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode;
@@ -24,8 +24,6 @@ import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
-import static com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType.tensorSize;
-
public class Reshape extends IntermediateOperation {
public Reshape(String modelName, String nodeName, List<IntermediateOperation> inputs) {
@@ -52,7 +50,7 @@ public class Reshape extends IntermediateOperation {
int size = cell.getValue().intValue();
if (size < 0) {
size = -1 * (int)shape.reduce(Reduce.Aggregator.prod).asDouble() /
- tensorSize(inputType.type()).intValue();
+ OrderedTensorType.tensorSize(inputType.type()).intValue();
}
outputTypeBuilder.add(TensorType.Dimension.indexed(
String.format("%s_%d", vespaName(), dimensionIndex), size));
@@ -82,7 +80,7 @@ public class Reshape extends IntermediateOperation {
}
public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) {
- if (!tensorSize(inputType).equals(tensorSize(outputType))) {
+ if (!OrderedTensorType.tensorSize(inputType).equals(OrderedTensorType.tensorSize(outputType))) {
throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping");
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/package-info.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/package-info.java
index 1530754cc43..bb55ed768a6 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/package-info.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/package-info.java
@@ -1,8 +1,8 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
/**
- * ONNX integration
+ * Model integration
*/
@ExportPackage
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer;
import com.yahoo.osgi.annotation.ExportPackage;
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/DimensionRenamerTest.java
index b3dafff621c..4bd28a74d6f 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/DimensionRenamerTest.java
@@ -1,5 +1,5 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
import org.junit.Test;
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/OrderedTensorTypeTestCase.java
index 55e1d234782..5118081637e 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/OrderedTensorTypeTestCase.java
@@ -1,5 +1,5 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import org.junit.Test;
diff --git a/standalone-container/vespa-standalone-container.spec b/standalone-container/vespa-standalone-container.spec
index f051142c516..6143df6a446 100644
--- a/standalone-container/vespa-standalone-container.spec
+++ b/standalone-container/vespa-standalone-container.spec
@@ -54,6 +54,7 @@ declare -a modules=(
jdisc_core
jdisc_http_service
model-evaluation
+ model-integration
security-utils
simplemetrics
standalone-container