aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--config-model-api/src/main/java/com/yahoo/config/model/api/ImportedMlFunction.java37
-rw-r--r--config-model-api/src/main/java/com/yahoo/config/model/api/ImportedMlModel.java23
-rw-r--r--config-model-api/src/main/java/com/yahoo/config/model/api/ImportedMlModels.java (renamed from model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java)68
-rw-r--r--config-model-api/src/main/java/com/yahoo/config/model/api/MlModelImporter.java16
-rw-r--r--config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java18
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java10
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java6
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java6
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java6
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java8
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java8
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java6
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java31
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java4
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java8
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java6
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java18
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java14
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java18
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java4
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java6
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java4
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java8
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java4
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java4
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java4
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java12
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java6
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java6
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java8
-rw-r--r--model-integration/pom.xml6
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java106
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java13
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/package-info.java11
34 files changed, 278 insertions, 235 deletions
diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/ImportedMlFunction.java b/config-model-api/src/main/java/com/yahoo/config/model/api/ImportedMlFunction.java
new file mode 100644
index 00000000000..54cdf807878
--- /dev/null
+++ b/config-model-api/src/main/java/com/yahoo/config/model/api/ImportedMlFunction.java
@@ -0,0 +1,37 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.config.model.api;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+/**
+ * An imported function of an imported machine-learned model
+ *
+ * @author bratseth
+ */
+public class ImportedMlFunction {
+
+ private final String name;
+ private final List<String> arguments;
+ private final Map<String, String> argumentTypes;
+ private final String expression;
+ private final Optional<String> returnType;
+
+ public ImportedMlFunction(String name, List<String> arguments, String expression,
+ Map<String, String> argumentTypes, Optional<String> returnType) {
+ this.name = name;
+ this.arguments = Collections.unmodifiableList(arguments);
+ this.expression = expression;
+ this.argumentTypes = Collections.unmodifiableMap(argumentTypes);
+ this.returnType = returnType;
+ }
+
+ public String name() { return name; }
+ public List<String> arguments() { return arguments; }
+ public Map<String, String> argumentTypes() { return argumentTypes; }
+ public String expression() { return expression; }
+ public Optional<String> returnType() { return returnType; }
+
+}
diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/ImportedMlModel.java b/config-model-api/src/main/java/com/yahoo/config/model/api/ImportedMlModel.java
new file mode 100644
index 00000000000..078e4c239d6
--- /dev/null
+++ b/config-model-api/src/main/java/com/yahoo/config/model/api/ImportedMlModel.java
@@ -0,0 +1,23 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.config.model.api;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+/**
+ * Config model view of an imported machine-learned model.
+ *
+ * @author bratseth
+ */
+public interface ImportedMlModel {
+
+ String name();
+ String source();
+ Optional<String> inputTypeSpec(String input);
+ Map<String, String> smallConstants();
+ Map<String, String> largeConstants();
+ Map<String, String> functions();
+ List<ImportedMlFunction> outputExpressions();
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java b/config-model-api/src/main/java/com/yahoo/config/model/api/ImportedMlModels.java
index bfdaaca1dd7..aeef81788b8 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java
+++ b/config-model-api/src/main/java/com/yahoo/config/model/api/ImportedMlModels.java
@@ -1,12 +1,12 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package ai.vespa.rankingexpression.importer;
+package com.yahoo.config.model.api;
-import com.google.common.collect.ImmutableMap;
import com.yahoo.path.Path;
import java.io.File;
import java.util.Arrays;
import java.util.Collection;
+import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
@@ -18,34 +18,51 @@ import java.util.Optional;
*
* @author bratseth
*/
-public class ImportedModels {
+public class ImportedMlModels {
/** All imported models, indexed by their names */
- private final ImmutableMap<String, ImportedModel> importedModels;
+ private final Map<String, ImportedMlModel> importedModels;
/** Create a null imported models */
- public ImportedModels() {
- importedModels = ImmutableMap.of();
+ public ImportedMlModels() {
+ importedModels = Collections.emptyMap();
}
- public ImportedModels(File modelsDirectory, Collection<ModelImporter> importers) {
- Map<String, ImportedModel> models = new HashMap<>();
+ public ImportedMlModels(File modelsDirectory, Collection<MlModelImporter> importers) {
+ Map<String, ImportedMlModel> models = new HashMap<>();
// Find all subdirectories recursively which contains a model we can read
importRecursively(modelsDirectory, models, importers);
- importedModels = ImmutableMap.copyOf(models);
+ importedModels = Collections.unmodifiableMap(models);
+ }
+
+ /**
+ * Returns the model at the given location in the application package.
+ *
+ * @param modelPath the path to this model (file or directory, depending on model type)
+ * under the application package, both from the root or relative to the
+ * models directory works
+ * @return the model at this path or null if none
+ */
+ public ImportedMlModel get(File modelPath) {
+ return importedModels.get(toName(modelPath));
+ }
+
+ /** Returns an immutable collection of all the imported models */
+ public Collection<ImportedMlModel> all() {
+ return importedModels.values();
}
private static void importRecursively(File dir,
- Map<String, ImportedModel> models,
- Collection<ModelImporter> importers) {
+ Map<String, ImportedMlModel> models,
+ Collection<MlModelImporter> importers) {
if ( ! dir.isDirectory()) return;
Arrays.stream(dir.listFiles()).sorted().forEach(child -> {
- Optional<ModelImporter> importer = findImporterOf(child, importers);
+ Optional<MlModelImporter> importer = findImporterOf(child, importers);
if (importer.isPresent()) {
String name = toName(child);
- ImportedModel existing = models.get(name);
+ ImportedMlModel existing = models.get(name);
if (existing != null)
throw new IllegalArgumentException("The models in " + child + " and " + existing.source() +
" both resolve to the model name '" + name + "'");
@@ -57,33 +74,10 @@ public class ImportedModels {
});
}
- private static Optional<ModelImporter> findImporterOf(File path, Collection<ModelImporter> importers) {
+ private static Optional<MlModelImporter> findImporterOf(File path, Collection<MlModelImporter> importers) {
return importers.stream().filter(item -> item.canImport(path.toString())).findFirst();
}
- /**
- * Returns the model at the given location in the application package.
- *
- * @param modelPath the path to this model (file or directory, depending on model type)
- * under the application package, both from the root or relative to the
- * models directory works
- * @return the model at this path or null if none
- */
- // CFG
- public ImportedModel get(File modelPath) {
- return importedModels.get(toName(modelPath));
- }
-
- public ImportedModel get(String modelName) {
- return importedModels.get(modelName);
- }
-
- /** Returns an immutable collection of all the imported models */
- // CFG
- public Collection<ImportedModel> all() {
- return importedModels.values();
- }
-
private static String toName(File modelFile) {
Path modelPath = Path.fromString(modelFile.toString());
if (modelFile.isFile())
diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/MlModelImporter.java b/config-model-api/src/main/java/com/yahoo/config/model/api/MlModelImporter.java
new file mode 100644
index 00000000000..d24eeb2d55a
--- /dev/null
+++ b/config-model-api/src/main/java/com/yahoo/config/model/api/MlModelImporter.java
@@ -0,0 +1,16 @@
+package com.yahoo.config.model.api;
+
+import java.io.File;
+
+/**
+ * Config model view of a machine-learned model importer
+ *
+ * @author bratseth
+ */
+public interface MlModelImporter {
+
+ boolean canImport(String modelPath);
+
+ ImportedMlModel importModel(String modelName, File modelPath);
+
+}
diff --git a/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java b/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java
index abca92eeb79..06b2452dbd4 100644
--- a/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java
+++ b/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java
@@ -1,8 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.config.model.deploy;
-import ai.vespa.rankingexpression.importer.ImportedModels;
-import ai.vespa.rankingexpression.importer.ModelImporter;
+import com.yahoo.config.model.api.ImportedMlModels;
import com.yahoo.component.Version;
import com.yahoo.component.Vtag;
import com.yahoo.config.application.api.ApplicationPackage;
@@ -11,6 +10,7 @@ import com.yahoo.config.application.api.FileRegistry;
import com.yahoo.config.application.api.UnparsedConfigDefinition;
import com.yahoo.config.model.api.ConfigDefinitionRepo;
import com.yahoo.config.model.api.HostProvisioner;
+import com.yahoo.config.model.api.MlModelImporter;
import com.yahoo.config.model.api.Model;
import com.yahoo.config.model.api.ValidationParameters;
import com.yahoo.config.model.application.provider.BaseDeployLogger;
@@ -68,7 +68,7 @@ public class DeployState implements ConfigDefinitionStore {
private final Zone zone;
private final QueryProfiles queryProfiles;
private final SemanticRules semanticRules;
- private final ImportedModels importedModels;
+ private final ImportedMlModels importedModels;
private final ValidationOverrides validationOverrides;
private final Version wantedNodeVespaVersion;
private final Instant now;
@@ -93,7 +93,7 @@ public class DeployState implements ConfigDefinitionStore {
Optional<ConfigDefinitionRepo> configDefinitionRepo,
java.util.Optional<Model> previousModel,
Set<Rotation> rotations,
- Collection<ModelImporter> modelImporters,
+ Collection<MlModelImporter> modelImporters,
Zone zone,
QueryProfiles queryProfiles,
SemanticRules semanticRules,
@@ -114,8 +114,8 @@ public class DeployState implements ConfigDefinitionStore {
this.zone = zone;
this.queryProfiles = queryProfiles; // TODO: Remove this by seeing how pagetemplates are propagated
this.semanticRules = semanticRules; // TODO: Remove this by seeing how pagetemplates are propagated
- this.importedModels = new ImportedModels(applicationPackage.getFileReference(ApplicationPackage.MODELS_DIR),
- modelImporters);
+ this.importedModels = new ImportedMlModels(applicationPackage.getFileReference(ApplicationPackage.MODELS_DIR),
+ modelImporters);
this.validationOverrides = applicationPackage.getValidationOverrides().map(ValidationOverrides::fromXml).orElse(ValidationOverrides.empty);
this.wantedNodeVespaVersion = wantedNodeVespaVersion;
@@ -230,7 +230,7 @@ public class DeployState implements ConfigDefinitionStore {
public SemanticRules getSemanticRules() { return semanticRules; }
/** The (machine learned) models imported from the models/ directory, as an unmodifiable map indexed by model name */
- public ImportedModels getImportedModels() { return importedModels; }
+ public ImportedMlModels getImportedModels() { return importedModels; }
public Version getWantedNodeVespaVersion() { return wantedNodeVespaVersion; }
@@ -247,7 +247,7 @@ public class DeployState implements ConfigDefinitionStore {
private Optional<ConfigDefinitionRepo> configDefinitionRepo = Optional.empty();
private Optional<Model> previousModel = Optional.empty();
private Set<Rotation> rotations = new HashSet<>();
- private Collection<ModelImporter> modelImporters = Collections.emptyList();
+ private Collection<MlModelImporter> modelImporters = Collections.emptyList();
private Zone zone = Zone.defaultZone();
private Instant now = Instant.now();
private Version wantedNodeVespaVersion = Vtag.currentVersion;
@@ -297,7 +297,7 @@ public class DeployState implements ConfigDefinitionStore {
return this;
}
- public Builder modelImporters(Collection<ModelImporter> modelImporters) {
+ public Builder modelImporters(Collection<MlModelImporter> modelImporters) {
this.modelImporters = modelImporters;
return this;
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
index e434babdeb4..7c0b90c35fa 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
@@ -1,7 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition;
-import ai.vespa.rankingexpression.importer.ImportedModels;
+import com.yahoo.config.model.api.ImportedMlModels;
import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.search.query.profile.types.FieldDescription;
@@ -647,7 +647,7 @@ public class RankProfile implements Serializable, Cloneable {
* Returns a copy of this where the content is optimized for execution.
* Compiled profiles should never be modified.
*/
- public RankProfile compile(QueryProfileRegistry queryProfiles, ImportedModels importedModels) {
+ public RankProfile compile(QueryProfileRegistry queryProfiles, ImportedMlModels importedModels) {
try {
RankProfile compiled = this.clone();
compiled.compileThis(queryProfiles, importedModels);
@@ -658,7 +658,7 @@ public class RankProfile implements Serializable, Cloneable {
}
}
- private void compileThis(QueryProfileRegistry queryProfiles, ImportedModels importedModels) {
+ private void compileThis(QueryProfileRegistry queryProfiles, ImportedMlModels importedModels) {
checkNameCollisions(getFunctions(), getConstants());
ExpressionTransforms expressionTransforms = new ExpressionTransforms();
@@ -688,7 +688,7 @@ public class RankProfile implements Serializable, Cloneable {
private Map<String, RankingExpressionFunction> compileFunctions(Supplier<Map<String, RankingExpressionFunction>> functions,
QueryProfileRegistry queryProfiles,
- ImportedModels importedModels,
+ ImportedMlModels importedModels,
Map<String, RankingExpressionFunction> inlineFunctions,
ExpressionTransforms expressionTransforms) {
Map<String, RankingExpressionFunction> compiledFunctions = new LinkedHashMap<>();
@@ -716,7 +716,7 @@ public class RankProfile implements Serializable, Cloneable {
private RankingExpression compile(RankingExpression expression,
QueryProfileRegistry queryProfiles,
- ImportedModels importedModels,
+ ImportedMlModels importedModels,
Map<String, Value> constants,
Map<String, RankingExpressionFunction> inlineFunctions,
ExpressionTransforms expressionTransforms) {
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java
index 271c335cd1f..7dc4b815da6 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java
@@ -1,7 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition.derived;
-import ai.vespa.rankingexpression.importer.ImportedModels;
+import com.yahoo.config.model.api.ImportedMlModels;
import com.yahoo.config.ConfigInstance;
import com.yahoo.config.model.application.provider.BaseDeployLogger;
import com.yahoo.config.application.api.DeployLogger;
@@ -49,7 +49,7 @@ public class DerivedConfiguration {
public DerivedConfiguration(Search search,
RankProfileRegistry rankProfileRegistry,
QueryProfileRegistry queryProfiles,
- ImportedModels importedModels) {
+ ImportedMlModels importedModels) {
this(search, new BaseDeployLogger(), rankProfileRegistry, queryProfiles, importedModels);
}
@@ -68,7 +68,7 @@ public class DerivedConfiguration {
DeployLogger deployLogger,
RankProfileRegistry rankProfileRegistry,
QueryProfileRegistry queryProfiles,
- ImportedModels importedModels) {
+ ImportedMlModels importedModels) {
Validator.ensureNotNull("Search definition", search);
this.search = search;
if ( ! search.isDocumentsOnly()) {
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
index 51e0c1d2f47..4c117e44857 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
@@ -1,7 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition.derived;
-import ai.vespa.rankingexpression.importer.ImportedModels;
+import com.yahoo.config.model.api.ImportedMlModels;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.RankProfileRegistry;
import com.yahoo.searchdefinition.RankingConstant;
@@ -45,7 +45,7 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
AttributeFields attributeFields,
RankProfileRegistry rankProfileRegistry,
QueryProfileRegistry queryProfiles,
- ImportedModels importedModels) {
+ ImportedMlModels importedModels) {
setName(search == null ? "default" : search.getName());
this.rankingConstants = rankingConstants;
deriveRankProfiles(rankProfileRegistry, queryProfiles, importedModels, search, attributeFields);
@@ -53,7 +53,7 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
private void deriveRankProfiles(RankProfileRegistry rankProfileRegistry,
QueryProfileRegistry queryProfiles,
- ImportedModels importedModels,
+ ImportedMlModels importedModels,
Search search,
AttributeFields attributeFields) {
if (search != null) { // profiles belonging to a search have a default profile
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
index 97b6def07ab..b7f515cedd4 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
@@ -1,7 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition.derived;
-import ai.vespa.rankingexpression.importer.ImportedModels;
+import com.yahoo.config.model.api.ImportedMlModels;
import com.google.common.collect.ImmutableList;
import com.yahoo.collections.Pair;
import com.yahoo.compress.Compressor;
@@ -50,7 +50,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
/**
* Creates a raw rank profile from the given rank profile
*/
- public RawRankProfile(RankProfile rankProfile, QueryProfileRegistry queryProfiles, ImportedModels importedModels, AttributeFields attributeFields) {
+ public RawRankProfile(RankProfile rankProfile, QueryProfileRegistry queryProfiles, ImportedMlModels importedModels, AttributeFields attributeFields) {
this.name = rankProfile.getName();
compressedProperties = compress(new Deriver(rankProfile, queryProfiles, importedModels, attributeFields).derive());
}
@@ -148,7 +148,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
/**
* Creates a raw rank profile from the given rank profile
*/
- public Deriver(RankProfile rankProfile, QueryProfileRegistry queryProfiles, ImportedModels importedModels, AttributeFields attributeFields) {
+ public Deriver(RankProfile rankProfile, QueryProfileRegistry queryProfiles, ImportedMlModels importedModels, AttributeFields attributeFields) {
RankProfile compiled = rankProfile.compile(queryProfiles, importedModels);
attributeTypes = compiled.getAttributeTypes();
queryFeatureTypes = compiled.getQueryFeatureTypes();
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java
index 9043bf966a3..f20298cfe1a 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java
@@ -1,7 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition.expressiontransforms;
-import ai.vespa.rankingexpression.importer.ImportedModels;
+import com.yahoo.config.model.api.ImportedMlModels;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
@@ -19,13 +19,13 @@ public class RankProfileTransformContext extends TransformContext {
private final RankProfile rankProfile;
private final QueryProfileRegistry queryProfiles;
- private final ImportedModels importedModels;
+ private final ImportedMlModels importedModels;
private final Map<String, RankProfile.RankingExpressionFunction> inlineFunctions;
private final Map<String, String> rankProperties = new HashMap<>();
public RankProfileTransformContext(RankProfile rankProfile,
QueryProfileRegistry queryProfiles,
- ImportedModels importedModels,
+ ImportedMlModels importedModels,
Map<String, Value> constants,
Map<String, RankProfile.RankingExpressionFunction> inlineFunctions) {
super(constants);
@@ -37,7 +37,7 @@ public class RankProfileTransformContext extends TransformContext {
public RankProfile rankProfile() { return rankProfile; }
public QueryProfileRegistry queryProfiles() { return queryProfiles; }
- public ImportedModels importedModels() { return importedModels; }
+ public ImportedMlModels importedModels() { return importedModels; }
public Map<String, RankProfile.RankingExpressionFunction> inlineFunctions() { return inlineFunctions; }
public Map<String, String> rankProperties() { return rankProperties; }
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
index a01e7cffd84..58fc08d15e7 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
@@ -1,8 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model;
-import ai.vespa.rankingexpression.importer.ImportedModel;
-import ai.vespa.rankingexpression.importer.ImportedModels;
+import com.yahoo.config.model.api.ImportedMlModel;
import com.yahoo.config.ConfigBuilder;
import com.yahoo.config.ConfigInstance;
import com.yahoo.config.ConfigInstance.Builder;
@@ -20,6 +19,7 @@ import com.yahoo.config.model.ConfigModelRepo;
import com.yahoo.config.model.NullConfigModelRegistry;
import com.yahoo.config.model.api.FileDistribution;
import com.yahoo.config.model.api.HostInfo;
+import com.yahoo.config.model.api.ImportedMlModels;
import com.yahoo.config.model.api.Model;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.config.model.producer.AbstractConfigProducer;
@@ -217,11 +217,11 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
* Creates a rank profile not attached to any search definition, for each imported model in the application package,
* and adds it to the given rank profile registry.
*/
- private void createGlobalRankProfiles(DeployLogger deployLogger, ImportedModels importedModels,
+ private void createGlobalRankProfiles(DeployLogger deployLogger, ImportedMlModels importedModels,
RankProfileRegistry rankProfileRegistry,
QueryProfiles queryProfiles) {
if ( ! importedModels.all().isEmpty()) { // models/ directory is available
- for (ImportedModel model : importedModels.all()) {
+ for (ImportedMlModel model : importedModels.all()) {
RankProfile profile = new RankProfile(model.name(), this, rankProfileRegistry);
rankProfileRegistry.add(profile);
ConvertedModel convertedModel = ConvertedModel.fromSource(new ModelName(model.name()),
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java
index 8fa9dfa9e91..954f20f36c0 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java
@@ -1,7 +1,6 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model;
-import ai.vespa.rankingexpression.importer.ModelImporter;
import com.google.inject.Inject;
import com.yahoo.component.provider.ComponentRegistry;
import com.yahoo.config.application.api.ApplicationPackage;
@@ -11,6 +10,7 @@ import com.yahoo.config.model.NullConfigModelRegistry;
import com.yahoo.config.model.api.ConfigChangeAction;
import com.yahoo.config.model.api.ConfigModelPlugin;
import com.yahoo.config.model.api.HostProvisioner;
+import com.yahoo.config.model.api.MlModelImporter;
import com.yahoo.config.model.api.Model;
import com.yahoo.config.model.api.ModelContext;
import com.yahoo.config.model.api.ModelCreateResult;
@@ -44,7 +44,7 @@ public class VespaModelFactory implements ModelFactory {
private static final Logger log = Logger.getLogger(VespaModelFactory.class.getName());
private final ConfigModelRegistry configModelRegistry;
- private final Collection<ModelImporter> modelImporters;
+ private final Collection<MlModelImporter> modelImporters;
private final Zone zone;
private final Clock clock;
private final Version version;
@@ -52,7 +52,7 @@ public class VespaModelFactory implements ModelFactory {
/** Creates a factory for vespa models for this version of the source */
@Inject
public VespaModelFactory(ComponentRegistry<ConfigModelPlugin> pluginRegistry,
- ComponentRegistry<ModelImporter> modelImporters,
+ ComponentRegistry<MlModelImporter> modelImporters,
Zone zone) {
this.version = Version.fromIntValues(VespaVersion.major, VespaVersion.minor, VespaVersion.micro);
List<ConfigModelBuilder> modelBuilders = new ArrayList<>();
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
index 259ac5227ae..c834bea7be2 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
@@ -1,11 +1,12 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model.ml;
-import ai.vespa.rankingexpression.importer.ImportedModel;
+import com.yahoo.config.model.api.ImportedMlFunction;
import com.google.common.collect.ImmutableMap;
import com.yahoo.collections.Pair;
import com.yahoo.config.application.api.ApplicationFile;
import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.config.model.api.ImportedMlModel;
import com.yahoo.config.model.application.provider.FilesApplicationPackage;
import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
@@ -70,12 +71,12 @@ public class ConvertedModel {
private final ImmutableMap<String, ExpressionFunction> expressions;
/** The source importedModel, or empty if this was created from a stored converted model */
- private final Optional<ImportedModel> sourceModel;
+ private final Optional<ImportedMlModel> sourceModel;
private ConvertedModel(ModelName modelName,
String modelDescription,
Map<String, ExpressionFunction> expressions,
- Optional<ImportedModel> sourceModel) {
+ Optional<ImportedMlModel> sourceModel) {
this.modelName = modelName;
this.modelDescription = modelDescription;
this.expressions = ImmutableMap.copyOf(expressions);
@@ -90,13 +91,13 @@ public class ConvertedModel {
* @param pathIsFile true if that path (this kind of model) is stored in a file, false if it is in a directory
*/
public static ConvertedModel fromSourceOrStore(Path modelPath, boolean pathIsFile, RankProfileTransformContext context) {
- ImportedModel sourceModel = // TODO: Convert to name here, make sure its done just one way
+ ImportedMlModel sourceModel = // TODO: Convert to name here, make sure its done just one way
context.importedModels().get(sourceModelFile(context.rankProfile().applicationPackage(), modelPath));
ModelName modelName = new ModelName(context.rankProfile().getName(), modelPath, pathIsFile);
if (sourceModel == null && ! new ModelStore(context.rankProfile().applicationPackage(), modelName).exists())
throw new IllegalArgumentException("No model '" + modelPath + "' is available. Available models: " +
- context.importedModels().all().stream().map(ImportedModel::source).collect(Collectors.joining(", ")));
+ context.importedModels().all().stream().map(ImportedMlModel::source).collect(Collectors.joining(", ")));
if (sourceModel != null) {
return fromSource(modelName,
@@ -116,7 +117,7 @@ public class ConvertedModel {
String modelDescription,
RankProfile rankProfile,
QueryProfileRegistry queryProfileRegistry,
- ImportedModel importedModel) {
+ ImportedMlModel importedModel) {
ModelStore modelStore = new ModelStore(rankProfile.applicationPackage(), modelName);
return new ConvertedModel(modelName,
modelDescription,
@@ -187,7 +188,7 @@ public class ConvertedModel {
// ----------------------- Static model conversion/storage below here
- private static Map<String, ExpressionFunction> convertAndStore(ImportedModel model,
+ private static Map<String, ExpressionFunction> convertAndStore(ImportedMlModel model,
RankProfile profile,
QueryProfileRegistry queryProfiles,
ModelStore store) {
@@ -202,7 +203,7 @@ public class ConvertedModel {
// Add expressions
Map<String, ExpressionFunction> expressions = new HashMap<>();
- for (ImportedModel.ImportedFunction outputFunction : model.outputExpressions()) {
+ for (ImportedMlFunction outputFunction : model.outputExpressions()) {
ExpressionFunction expression = asExpressionFunction(outputFunction);
addExpression(expression, expression.getName(),
constantsReplacedByFunctions,
@@ -219,7 +220,7 @@ public class ConvertedModel {
return expressions;
}
- private static ExpressionFunction asExpressionFunction(ImportedModel.ImportedFunction function) {
+ private static ExpressionFunction asExpressionFunction(ImportedMlFunction function) {
try {
Map<String, TensorType> argumentTypes = new HashMap<>();
for (Map.Entry<String, String> entry : function.argumentTypes().entrySet())
@@ -239,7 +240,7 @@ public class ConvertedModel {
private static void addExpression(ExpressionFunction expression,
String expressionName,
Set<String> constantsReplacedByFunctions,
- ImportedModel model,
+ ImportedMlModel model,
ModelStore store,
RankProfile profile,
QueryProfileRegistry queryProfiles,
@@ -322,7 +323,7 @@ public class ConvertedModel {
* Verify that the inputs declared in the given expression exists in the given rank profile as functions,
* and return tensors of the correct types.
*/
- private static void verifyInputs(RankingExpression expression, ImportedModel model,
+ private static void verifyInputs(RankingExpression expression, ImportedMlModel model,
RankProfile profile, QueryProfileRegistry queryProfiles) {
Set<String> functionNames = new HashSet<>();
addFunctionNamesIn(expression.getRoot(), functionNames, model);
@@ -359,7 +360,7 @@ public class ConvertedModel {
}
/** Add the generated functions to the rank profile */
- private static void addGeneratedFunctions(ImportedModel model, RankProfile profile) {
+ private static void addGeneratedFunctions(ImportedMlModel model, RankProfile profile) {
model.functions().forEach((k, v) -> addGeneratedFunctionToProfile(profile, k, RankingExpression.from(v)));
}
@@ -368,7 +369,7 @@ public class ConvertedModel {
* function specifies that a single exemplar should be evaluated, we can
* reduce the batch dimension out.
*/
- private static void reduceBatchDimensions(RankingExpression expression, ImportedModel model,
+ private static void reduceBatchDimensions(RankingExpression expression, ImportedMlModel model,
RankProfile profile, QueryProfileRegistry queryProfiles) {
TypeContext<Reference> typeContext = profile.typeContext(queryProfiles);
TensorType typeBeforeReducing = expression.getRoot().type(typeContext);
@@ -396,7 +397,7 @@ public class ConvertedModel {
expression.setRoot(root);
}
- private static ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedModel model,
+ private static ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedMlModel model,
TypeContext<Reference> typeContext) {
if (node instanceof TensorFunctionNode) {
TensorFunction tensorFunction = ((TensorFunctionNode) node).function();
@@ -502,7 +503,7 @@ public class ConvertedModel {
return node;
}
- private static void addFunctionNamesIn(ExpressionNode node, Set<String> names, ImportedModel model) {
+ private static void addFunctionNamesIn(ExpressionNode node, Set<String> names, ImportedMlModel model) {
if (node instanceof ReferenceNode) {
ReferenceNode referenceNode = (ReferenceNode)node;
if (referenceNode.getOutput() == null) { // function references cannot specify outputs
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java
index 9af5d82181e..131972ffc73 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java
@@ -4,7 +4,7 @@ package com.yahoo.searchdefinition;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.derived.DerivedConfiguration;
import com.yahoo.searchdefinition.parser.ParseException;
-import ai.vespa.rankingexpression.importer.ImportedModels;
+import com.yahoo.config.model.api.ImportedMlModels;
import com.yahoo.yolean.Exceptions;
import org.junit.Test;
@@ -26,7 +26,7 @@ public class IncorrectRankingExpressionFileRefTestCase extends SearchDefinitionT
Search search = SearchBuilder.buildFromFile("src/test/examples/incorrectrankingexpressionfileref.sd",
registry,
new QueryProfileRegistry());
- new DerivedConfiguration(search, registry, new QueryProfileRegistry(), new ImportedModels()); // cause rank profile parsing
+ new DerivedConfiguration(search, registry, new QueryProfileRegistry(), new ImportedMlModels()); // cause rank profile parsing
fail("parsing should have failed");
} catch (IllegalArgumentException e) {
String message = Exceptions.toMessageString(e);
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java
index c2b7eb90487..06761ad45bc 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java
@@ -16,7 +16,7 @@ import com.yahoo.searchdefinition.document.RankType;
import com.yahoo.searchdefinition.document.SDDocumentType;
import com.yahoo.searchdefinition.document.SDField;
import com.yahoo.searchdefinition.parser.ParseException;
-import ai.vespa.rankingexpression.importer.ImportedModels;
+import com.yahoo.config.model.api.ImportedMlModels;
import org.junit.Test;
import java.util.Iterator;
@@ -91,7 +91,7 @@ public class RankProfileTestCase extends SearchDefinitionTestCase {
assertEquals(8, rankProfile.getNumThreadsPerSearch());
assertEquals(70, rankProfile.getMinHitsPerThread());
assertEquals(1200, rankProfile.getNumSearchPartitions());
- RawRankProfile rawRankProfile = new RawRankProfile(rankProfile, new QueryProfileRegistry(), new ImportedModels(), attributeFields);
+ RawRankProfile rawRankProfile = new RawRankProfile(rankProfile, new QueryProfileRegistry(), new ImportedMlModels(), attributeFields);
assertTrue(findProperty(rawRankProfile.configProperties(), "vespa.matching.termwise_limit").isPresent());
assertEquals("0.78", findProperty(rawRankProfile.configProperties(), "vespa.matching.termwise_limit").get());
assertTrue(findProperty(rawRankProfile.configProperties(), "vespa.matching.numthreadspersearch").isPresent());
@@ -126,7 +126,7 @@ public class RankProfileTestCase extends SearchDefinitionTestCase {
}
private static void assertAttributeTypeSettings(RankProfile profile, Search search) {
- RawRankProfile rawProfile = new RawRankProfile(profile, new QueryProfileRegistry(), new ImportedModels(), new AttributeFields(search));
+ RawRankProfile rawProfile = new RawRankProfile(profile, new QueryProfileRegistry(), new ImportedMlModels(), new AttributeFields(search));
assertEquals("tensor(x[10])", findProperty(rawProfile.configProperties(), "vespa.type.attribute.a").get());
assertEquals("tensor(y{})", findProperty(rawProfile.configProperties(), "vespa.type.attribute.b").get());
assertEquals("tensor(x[])", findProperty(rawProfile.configProperties(), "vespa.type.attribute.c").get());
@@ -168,7 +168,7 @@ public class RankProfileTestCase extends SearchDefinitionTestCase {
}
private static void assertQueryFeatureTypeSettings(RankProfile profile, Search search) {
- RawRankProfile rawProfile = new RawRankProfile(profile, new QueryProfileRegistry(), new ImportedModels(), new AttributeFields(search));
+ RawRankProfile rawProfile = new RawRankProfile(profile, new QueryProfileRegistry(), new ImportedMlModels(), new AttributeFields(search));
assertEquals("tensor(x[10])", findProperty(rawProfile.configProperties(), "vespa.type.query.tensor1").get());
assertEquals("tensor(y{})", findProperty(rawProfile.configProperties(), "vespa.type.query.tensor2").get());
assertFalse(findProperty(rawProfile.configProperties(), "vespa.type.query.tensor3").isPresent());
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java
index 9ee32c7e6b9..1d6a75f039d 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java
@@ -5,7 +5,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RawRankProfile;
import com.yahoo.searchdefinition.parser.ParseException;
-import ai.vespa.rankingexpression.importer.ImportedModels;
+import com.yahoo.config.model.api.ImportedMlModels;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
@@ -55,7 +55,7 @@ public class RankPropertiesTestCase extends SearchDefinitionTestCase {
assertEquals("query(a) = 1500", parent.getRankProperties().get(0).toString());
// Check derived model
- RawRankProfile rawParent = new RawRankProfile(parent, new QueryProfileRegistry(), new ImportedModels(), attributeFields);
+ RawRankProfile rawParent = new RawRankProfile(parent, new QueryProfileRegistry(), new ImportedMlModels(), attributeFields);
assertEquals("(query(a),1500)", rawParent.configProperties().get(0).toString());
}
@@ -67,7 +67,7 @@ public class RankPropertiesTestCase extends SearchDefinitionTestCase {
// Check derived model
RawRankProfile rawChild = new RawRankProfile(rankProfileRegistry.get(search, "child"),
new QueryProfileRegistry(),
- new ImportedModels(),
+ new ImportedMlModels(),
attributeFields);
assertEquals("(query(a),2000)", rawChild.configProperties().get(0).toString());
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java
index 6b6498528e4..af6507f352d 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java
@@ -3,7 +3,7 @@ package com.yahoo.searchdefinition;
import com.yahoo.collections.Pair;
import com.yahoo.search.query.profile.QueryProfileRegistry;
-import ai.vespa.rankingexpression.importer.ImportedModels;
+import com.yahoo.config.model.api.ImportedMlModels;
import com.yahoo.yolean.Exceptions;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RawRankProfile;
@@ -67,19 +67,19 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase
"}\n");
builder.build();
Search s = builder.getSearch();
- RankProfile parent = rankProfileRegistry.get(s, "parent").compile(queryProfileRegistry, new ImportedModels());
+ RankProfile parent = rankProfileRegistry.get(s, "parent").compile(queryProfileRegistry, new ImportedMlModels());
assertEquals("0.0", parent.getFirstPhaseRanking().getRoot().toString());
- RankProfile child1 = rankProfileRegistry.get(s, "child1").compile(queryProfileRegistry, new ImportedModels());
+ RankProfile child1 = rankProfileRegistry.get(s, "child1").compile(queryProfileRegistry, new ImportedMlModels());
assertEquals("6.5", child1.getFirstPhaseRanking().getRoot().toString());
assertEquals("11.5", child1.getSecondPhaseRanking().getRoot().toString());
- RankProfile child2 = rankProfileRegistry.get(s, "child2").compile(queryProfileRegistry, new ImportedModels());
+ RankProfile child2 = rankProfileRegistry.get(s, "child2").compile(queryProfileRegistry, new ImportedMlModels());
assertEquals("16.6", child2.getFirstPhaseRanking().getRoot().toString());
assertEquals("foo: 14.0", child2.getFunctions().get("foo").function().getBody().toString());
List<Pair<String, String>> rankProperties = new RawRankProfile(child2,
queryProfileRegistry,
- new ImportedModels(),
+ new ImportedMlModels(),
new AttributeFields(s)).configProperties();
assertEquals("(rankingExpression(foo).rankingScript,14.0)", rankProperties.get(0).toString());
assertEquals("(rankingExpression(firstphase).rankingScript,16.6)", rankProperties.get(2).toString());
@@ -110,7 +110,7 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase
builder.build();
Search s = builder.getSearch();
try {
- rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedModels());
+ rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedMlModels());
fail("Should have caused an exception");
}
catch (IllegalArgumentException e) {
@@ -171,7 +171,7 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase
RankProfile profile = rankProfileRegistry.get(s, "test");
assertEquals("safeLog(popShareSlowDecaySignal,myValue)", profile.getFunctions().get("POP_SLOW_SCORE").function().getBody().getRoot().toString());
assertEquals("safeLog(popShareSlowDecaySignal,-9.21034037)",
- profile.compile(new QueryProfileRegistry(), new ImportedModels()).getFunctions().get("POP_SLOW_SCORE").function().getBody().getRoot().toString());
+ profile.compile(new QueryProfileRegistry(), new ImportedMlModels()).getFunctions().get("POP_SLOW_SCORE").function().getBody().getRoot().toString());
}
@Test
@@ -194,7 +194,7 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase
Search s = builder.getSearch();
RankProfile profile = rankProfileRegistry.get(s, "test");
assertEquals("k1 + (k2 + k3) / 100000000.0",
- profile.compile(new QueryProfileRegistry(), new ImportedModels()).getFunctions().get("rank_default").function().getBody().getRoot().toString());
+ profile.compile(new QueryProfileRegistry(), new ImportedMlModels()).getFunctions().get("rank_default").function().getBody().getRoot().toString());
}
@Test
@@ -220,7 +220,7 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase
Search s = builder.getSearch();
RankProfile profile = rankProfileRegistry.get(s, "test");
assertEquals("0.5 + 50 * (attribute(rating_yelp) - 3)",
- profile.compile(new QueryProfileRegistry(), new ImportedModels()).getFunctions().get("rank_default").function().getBody().getRoot().toString());
+ profile.compile(new QueryProfileRegistry(), new ImportedMlModels()).getFunctions().get("rank_default").function().getBody().getRoot().toString());
}
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java
index facb258ce28..368f6fec80e 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java
@@ -6,7 +6,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RawRankProfile;
import com.yahoo.searchdefinition.parser.ParseException;
-import ai.vespa.rankingexpression.importer.ImportedModels;
+import com.yahoo.config.model.api.ImportedMlModels;
import org.junit.Test;
import java.util.Optional;
@@ -63,10 +63,10 @@ public class RankingExpressionInliningTestCase extends SearchDefinitionTestCase
builder.build();
Search s = builder.getSearch();
- RankProfile parent = rankProfileRegistry.get(s, "parent").compile(new QueryProfileRegistry(), new ImportedModels());
+ RankProfile parent = rankProfileRegistry.get(s, "parent").compile(new QueryProfileRegistry(), new ImportedMlModels());
assertEquals("7.0 * (3 + attribute(a) + attribute(b) * (attribute(a) * 3 + if (7.0 < attribute(a), 1, 2) == 0))",
parent.getFirstPhaseRanking().getRoot().toString());
- RankProfile child = rankProfileRegistry.get(s, "child").compile(new QueryProfileRegistry(), new ImportedModels());
+ RankProfile child = rankProfileRegistry.get(s, "child").compile(new QueryProfileRegistry(), new ImportedMlModels());
assertEquals("7.0 * (9 + attribute(a))",
child.getFirstPhaseRanking().getRoot().toString());
}
@@ -123,14 +123,14 @@ public class RankingExpressionInliningTestCase extends SearchDefinitionTestCase
builder.build();
Search s = builder.getSearch();
- RankProfile parent = rankProfileRegistry.get(s, "parent").compile(new QueryProfileRegistry(), new ImportedModels());
+ RankProfile parent = rankProfileRegistry.get(s, "parent").compile(new QueryProfileRegistry(), new ImportedMlModels());
assertEquals("17.0", parent.getFirstPhaseRanking().getRoot().toString());
assertEquals("0.0", parent.getSecondPhaseRanking().getRoot().toString());
assertEquals("10.0", getRankingExpression("foo", parent, s));
assertEquals("17.0", getRankingExpression("firstphase", parent, s));
assertEquals("0.0", getRankingExpression("secondphase", parent, s));
- RankProfile child = rankProfileRegistry.get(s, "child").compile(new QueryProfileRegistry(), new ImportedModels());
+ RankProfile child = rankProfileRegistry.get(s, "child").compile(new QueryProfileRegistry(), new ImportedMlModels());
assertEquals("31.0 + bar + arg(4.0)", child.getFirstPhaseRanking().getRoot().toString());
assertEquals("24.0", child.getSecondPhaseRanking().getRoot().toString());
assertEquals("12.0", getRankingExpression("foo", child, s));
@@ -179,7 +179,7 @@ public class RankingExpressionInliningTestCase extends SearchDefinitionTestCase
builder.build();
Search s = builder.getSearch();
- RankProfile test = rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedModels());
+ RankProfile test = rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedMlModels());
assertEquals("attribute(a) + C + (attribute(b) + 1)", test.getFirstPhaseRanking().getRoot().toString());
assertEquals("attribute(a) + attribute(b)", getRankingExpression("C", test, s));
assertEquals("attribute(b) + 1", getRankingExpression("D", test, s));
@@ -210,7 +210,7 @@ public class RankingExpressionInliningTestCase extends SearchDefinitionTestCase
private String getRankingExpression(String name, RankProfile rankProfile, Search search) {
Optional<String> rankExpression =
- new RawRankProfile(rankProfile, new QueryProfileRegistry(), new ImportedModels(), new AttributeFields(search))
+ new RawRankProfile(rankProfile, new QueryProfileRegistry(), new ImportedMlModels(), new AttributeFields(search))
.configProperties()
.stream()
.filter(r -> r.getFirst().equals("rankingExpression(" + name + ").rankingScript"))
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java
index 849a58b4c0e..a0deedb404a 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java
@@ -9,7 +9,7 @@ import com.yahoo.search.query.profile.types.QueryProfileType;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RawRankProfile;
import com.yahoo.searchdefinition.parser.ParseException;
-import ai.vespa.rankingexpression.importer.ImportedModels;
+import com.yahoo.config.model.api.ImportedMlModels;
import org.junit.Test;
import java.util.List;
@@ -45,10 +45,10 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase
"}\n");
builder.build();
Search s = builder.getSearch();
- RankProfile test = rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedModels());
+ RankProfile test = rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedMlModels());
List<Pair<String, String>> testRankProperties = new RawRankProfile(test,
new QueryProfileRegistry(),
- new ImportedModels(),
+ new ImportedMlModels(),
new AttributeFields(s)).configProperties();
assertEquals("(rankingExpression(sin).rankingScript,x * x)",
testRankProperties.get(0).toString());
@@ -89,10 +89,10 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase
"}\n");
builder.build();
Search s = builder.getSearch();
- RankProfile test = rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedModels());
+ RankProfile test = rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedMlModels());
List<Pair<String, String>> testRankProperties = new RawRankProfile(test,
new QueryProfileRegistry(),
- new ImportedModels(),
+ new ImportedMlModels(),
new AttributeFields(s)).configProperties();
assertEquals("(rankingExpression(tan).rankingScript,x * x)",
testRankProperties.get(0).toString());
@@ -139,10 +139,10 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase
"}\n");
builder.build();
Search s = builder.getSearch();
- RankProfile test = rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedModels());
+ RankProfile test = rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedMlModels());
List<Pair<String, String>> testRankProperties = new RawRankProfile(test,
new QueryProfileRegistry(),
- new ImportedModels(),
+ new ImportedMlModels(),
new AttributeFields(s)).configProperties();
assertEquals("(rankingExpression(sin).rankingScript,x * x)",
testRankProperties.get(0).toString());
@@ -203,10 +203,10 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase
"}\n");
builder.build();
Search s = builder.getSearch();
- RankProfile test = rankProfileRegistry.get(s, "test").compile(queryProfiles, new ImportedModels());
+ RankProfile test = rankProfileRegistry.get(s, "test").compile(queryProfiles, new ImportedMlModels());
List<Pair<String, String>> testRankProperties = new RawRankProfile(test,
queryProfiles,
- new ImportedModels(),
+ new ImportedMlModels(),
new AttributeFields(s)).configProperties();
assertEquals("(rankingExpression(relu).rankingScript,max(1.0,x))",
testRankProperties.get(0).toString());
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java
index e4cf5aba5bd..830b7d531c3 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java
@@ -4,7 +4,7 @@ package com.yahoo.searchdefinition;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.derived.DerivedConfiguration;
import com.yahoo.searchdefinition.parser.ParseException;
-import ai.vespa.rankingexpression.importer.ImportedModels;
+import com.yahoo.config.model.api.ImportedMlModels;
import com.yahoo.yolean.Exceptions;
import org.junit.Test;
@@ -26,7 +26,7 @@ public class RankingExpressionValidationTestCase extends SearchDefinitionTestCas
try {
RankProfileRegistry registry = new RankProfileRegistry();
Search search = importWithExpression(expression, registry);
- new DerivedConfiguration(search, registry, new QueryProfileRegistry(), new ImportedModels()); // cause rank profile parsing
+ new DerivedConfiguration(search, registry, new QueryProfileRegistry(), new ImportedMlModels()); // cause rank profile parsing
fail("No exception on incorrect ranking expression " + expression);
} catch (IllegalArgumentException e) {
// Success
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java
index 2e942d4b8d7..5e8a4597a2d 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java
@@ -7,7 +7,7 @@ import com.yahoo.searchdefinition.Search;
import com.yahoo.searchdefinition.SearchBuilder;
import com.yahoo.searchdefinition.SearchDefinitionTestCase;
import com.yahoo.searchdefinition.parser.ParseException;
-import ai.vespa.rankingexpression.importer.ImportedModels;
+import com.yahoo.config.model.api.ImportedMlModels;
import com.yahoo.vespa.configmodel.producers.DocumentManager;
import com.yahoo.vespa.configmodel.producers.DocumentTypes;
@@ -37,7 +37,7 @@ public abstract class AbstractExportingTestCase extends SearchDefinitionTestCase
DerivedConfiguration config = new DerivedConfiguration(builder.getSearch(searchDefinitionName),
builder.getRankProfileRegistry(),
builder.getQueryProfileRegistry(),
- new ImportedModels());
+ new ImportedMlModels());
return export(dirName, builder, config);
}
@@ -45,7 +45,7 @@ public abstract class AbstractExportingTestCase extends SearchDefinitionTestCase
DerivedConfiguration config = new DerivedConfiguration(search,
builder.getRankProfileRegistry(),
builder.getQueryProfileRegistry(),
- new ImportedModels());
+ new ImportedMlModels());
return export(dirName, builder, config);
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java
index 1ed9dcd9a74..2160dda45aa 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java
@@ -9,7 +9,7 @@ import com.yahoo.searchdefinition.SearchBuilder;
import com.yahoo.searchdefinition.SearchDefinitionTestCase;
import com.yahoo.searchdefinition.document.SDDocumentType;
import com.yahoo.searchdefinition.document.SDField;
-import ai.vespa.rankingexpression.importer.ImportedModels;
+import com.yahoo.config.model.api.ImportedMlModels;
import org.junit.Test;
/**
@@ -32,7 +32,7 @@ public class EmptyRankProfileTestCase extends SearchDefinitionTestCase {
doc.addField(new SDField("c", DataType.STRING));
search = SearchBuilder.buildFromRawSearch(search, rankProfileRegistry, new QueryProfileRegistry());
- new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(), new ImportedModels());
+ new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(), new ImportedMlModels());
}
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java
index 674ae5274b0..a7821615f48 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java
@@ -11,7 +11,7 @@ import com.yahoo.searchdefinition.SearchBuilder;
import com.yahoo.searchdefinition.document.SDDocumentType;
import com.yahoo.searchdefinition.document.SDField;
import com.yahoo.searchdefinition.processing.Processing;
-import ai.vespa.rankingexpression.importer.ImportedModels;
+import com.yahoo.config.model.api.ImportedMlModels;
import com.yahoo.vespa.model.container.search.QueryProfiles;
import org.junit.Test;
@@ -42,7 +42,7 @@ public class LiteralBoostTestCase extends AbstractExportingTestCase {
other.addRankSetting(new RankProfile.RankSetting("a", RankProfile.RankSetting.Type.LITERALBOOST, 333));
new Processing().process(search, new BaseDeployLogger(), rankProfileRegistry, new QueryProfiles(), true, false);
- DerivedConfiguration derived=new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(), new ImportedModels());
+ DerivedConfiguration derived=new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(), new ImportedMlModels());
// Check attribute fields
derived.getAttributeFields(); // TODO: assert content
@@ -73,7 +73,7 @@ public class LiteralBoostTestCase extends AbstractExportingTestCase {
other.addRankSetting(new RankProfile.RankSetting("a", RankProfile.RankSetting.Type.LITERALBOOST, 333));
search = SearchBuilder.buildFromRawSearch(search, rankProfileRegistry, new QueryProfileRegistry());
- DerivedConfiguration derived = new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(),new ImportedModels());
+ DerivedConfiguration derived = new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(),new ImportedMlModels());
// Check il script addition
assertIndexing(Arrays.asList("clear_state | guard { input a | tokenize normalize stem:\"SHORTEST\" | index a; }",
@@ -100,7 +100,7 @@ public class LiteralBoostTestCase extends AbstractExportingTestCase {
field2.setLiteralBoost(20);
search = SearchBuilder.buildFromRawSearch(search, rankProfileRegistry, new QueryProfileRegistry());
- new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(), new ImportedModels());
+ new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(), new ImportedMlModels());
assertIndexing(Arrays.asList("clear_state | guard { input title | tokenize normalize stem:\"SHORTEST\" | summary title | index title; }",
"clear_state | guard { input body | tokenize normalize stem:\"SHORTEST\" | summary body | index body; }",
"clear_state | guard { input title | tokenize | index title_literal; }",
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java
index cfbcbc74f2d..61d1cd36f56 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java
@@ -5,7 +5,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.Search;
import com.yahoo.searchdefinition.SearchBuilder;
import com.yahoo.searchdefinition.parser.ParseException;
-import ai.vespa.rankingexpression.importer.ImportedModels;
+import com.yahoo.config.model.api.ImportedMlModels;
import org.junit.Test;
import java.io.File;
@@ -36,7 +36,7 @@ public class SimpleInheritTestCase extends AbstractExportingTestCase {
DerivedConfiguration config = new DerivedConfiguration(search,
builder.getRankProfileRegistry(),
new QueryProfileRegistry(),
- new ImportedModels());
+ new ImportedMlModels());
config.export(toDirName);
checkDir(toDirName, expectedResultsDirName);
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java
index 24a6b0d6aaa..a34d4de4f51 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java
@@ -10,7 +10,7 @@ import com.yahoo.searchdefinition.SearchDefinitionTestCase;
import com.yahoo.searchdefinition.document.SDDocumentType;
import com.yahoo.searchdefinition.document.SDField;
import com.yahoo.searchdefinition.processing.Processing;
-import ai.vespa.rankingexpression.importer.ImportedModels;
+import com.yahoo.config.model.api.ImportedMlModels;
import com.yahoo.vespa.model.container.search.QueryProfiles;
import org.junit.Test;
@@ -34,7 +34,7 @@ public class TypeConversionTestCase extends SearchDefinitionTestCase {
document.addField(a);
new Processing().process(search, new BaseDeployLogger(), rankProfileRegistry, new QueryProfiles(), true, false);
- DerivedConfiguration derived = new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(), new ImportedModels());
+ DerivedConfiguration derived = new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(), new ImportedMlModels());
IndexInfo indexInfo = derived.getIndexInfo();
assertFalse(indexInfo.hasCommand("default", "compact-to-term"));
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java
index 599c59c30a4..ae70061696b 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java
@@ -8,7 +8,7 @@ import com.yahoo.searchdefinition.SearchDefinitionTestCase;
import com.yahoo.searchdefinition.derived.DerivedConfiguration;
import com.yahoo.searchdefinition.document.SDDocumentType;
import com.yahoo.searchdefinition.parser.ParseException;
-import ai.vespa.rankingexpression.importer.ImportedModels;
+import com.yahoo.config.model.api.ImportedMlModels;
import org.junit.Test;
import java.io.IOException;
@@ -98,7 +98,7 @@ public class ImplicitSearchFieldsTestCase extends SearchDefinitionTestCase {
sb.importFile("src/test/examples/nextgen/simple.sd");
sb.build();
assertNotNull(sb.getSearch());
- new DerivedConfiguration(sb.getSearch(), sb.getRankProfileRegistry(), new QueryProfileRegistry(), new ImportedModels());
+ new DerivedConfiguration(sb.getSearch(), sb.getRankProfileRegistry(), new QueryProfileRegistry(), new ImportedMlModels());
}
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
index 862934935e4..9df03f25cb3 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
@@ -3,6 +3,7 @@ package com.yahoo.searchdefinition.processing;
import com.google.common.collect.ImmutableList;
import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.config.model.api.MlModelImporter;
import com.yahoo.config.model.test.MockApplicationPackage;
import com.yahoo.path.Path;
import com.yahoo.search.query.profile.QueryProfileRegistry;
@@ -11,8 +12,7 @@ import com.yahoo.searchdefinition.RankProfileRegistry;
import com.yahoo.searchdefinition.Search;
import com.yahoo.searchdefinition.SearchBuilder;
import com.yahoo.searchdefinition.parser.ParseException;
-import ai.vespa.rankingexpression.importer.ImportedModels;
-import ai.vespa.rankingexpression.importer.ModelImporter;
+import com.yahoo.config.model.api.ImportedMlModels;
import ai.vespa.rankingexpression.importer.onnx.OnnxImporter;
import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter;
import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter;
@@ -31,9 +31,9 @@ import static org.junit.Assert.assertEquals;
*/
class RankProfileSearchFixture {
- private final ImmutableList<ModelImporter> importers = ImmutableList.of(new TensorFlowImporter(),
- new OnnxImporter(),
- new XGBoostImporter());
+ private final ImmutableList<MlModelImporter> importers = ImmutableList.of(new TensorFlowImporter(),
+ new OnnxImporter(),
+ new XGBoostImporter());
private RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
private final QueryProfileRegistry queryProfileRegistry;
private Search search;
@@ -92,7 +92,7 @@ class RankProfileSearchFixture {
public RankProfile compileRankProfile(String rankProfile, Path applicationDir) {
RankProfile compiled = rankProfileRegistry.get(search, rankProfile)
.compile(queryProfileRegistry,
- new ImportedModels(applicationDir.toFile(), importers));
+ new ImportedMlModels(applicationDir.toFile(), importers));
compiledRankProfiles.put(rankProfile, compiled);
return compiled;
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java
index 301c0d1d31e..d8f1a2ba545 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java
@@ -8,7 +8,7 @@ import com.yahoo.searchdefinition.derived.DerivedConfiguration;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RawRankProfile;
import com.yahoo.searchdefinition.parser.ParseException;
-import ai.vespa.rankingexpression.importer.ImportedModels;
+import com.yahoo.config.model.api.ImportedMlModels;
import org.junit.Test;
import java.io.IOException;
@@ -39,7 +39,7 @@ public class RankingExpressionsTestCase extends SearchDefinitionTestCase {
List<Pair<String, String>> rankProperties = new RawRankProfile(functionsRankProfile,
new QueryProfileRegistry(),
- new ImportedModels(),
+ new ImportedMlModels(),
new AttributeFields(search)).configProperties();
assertEquals(6, rankProperties.size());
@@ -65,7 +65,7 @@ public class RankingExpressionsTestCase extends SearchDefinitionTestCase {
Search search = SearchBuilder.createFromDirectory("src/test/examples/rankingexpressioninfile",
registry,
new QueryProfileRegistry()).getSearch();
- new DerivedConfiguration(search, registry, new QueryProfileRegistry(), new ImportedModels()); // rank profile parsing happens during deriving
+ new DerivedConfiguration(search, registry, new QueryProfileRegistry(), new ImportedMlModels()); // rank profile parsing happens during deriving
}
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
index dc5ad8524b5..fe1d722a49c 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
@@ -17,7 +17,7 @@ import com.yahoo.searchdefinition.SearchDefinitionTestCase;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RawRankProfile;
import com.yahoo.searchdefinition.parser.ParseException;
-import ai.vespa.rankingexpression.importer.ImportedModels;
+import com.yahoo.config.model.api.ImportedMlModels;
import org.junit.Test;
import java.util.List;
@@ -200,10 +200,10 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase {
"}\n");
builder.build(true, new BaseDeployLogger());
Search s = builder.getSearch();
- RankProfile test = rankProfileRegistry.get(s, "test").compile(queryProfiles, new ImportedModels());
+ RankProfile test = rankProfileRegistry.get(s, "test").compile(queryProfiles, new ImportedMlModels());
List<Pair<String, String>> testRankProperties = new RawRankProfile(test,
queryProfiles,
- new ImportedModels(),
+ new ImportedMlModels(),
new AttributeFields(s)).configProperties();
return testRankProperties;
}
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java
index 125aefece0e..f5edc83da5c 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java
@@ -3,12 +3,12 @@ package com.yahoo.vespa.model.ml;
import com.google.common.collect.ImmutableList;
import com.yahoo.config.model.ApplicationPackageTester;
+import com.yahoo.config.model.api.MlModelImporter;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.searchdefinition.RankingConstant;
-import ai.vespa.rankingexpression.importer.ModelImporter;
import ai.vespa.rankingexpression.importer.onnx.OnnxImporter;
import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter;
import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter;
@@ -32,9 +32,9 @@ import static org.junit.Assert.assertEquals;
*/
public class ImportedModelTester {
- private final ImmutableList<ModelImporter> importers = ImmutableList.of(new TensorFlowImporter(),
- new OnnxImporter(),
- new XGBoostImporter());
+ private final ImmutableList<MlModelImporter> importers = ImmutableList.of(new TensorFlowImporter(),
+ new OnnxImporter(),
+ new XGBoostImporter());
private final String modelName;
private final Path applicationDir;
diff --git a/model-integration/pom.xml b/model-integration/pom.xml
index 9971e78d3c5..5a2e7f0dbcd 100644
--- a/model-integration/pom.xml
+++ b/model-integration/pom.xml
@@ -29,6 +29,12 @@
</dependency>
<dependency>
<groupId>com.yahoo.vespa</groupId>
+ <artifactId>config-model-api</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>com.yahoo.vespa</groupId>
<artifactId>searchlib</artifactId>
<version>${project.version}</version>
<scope>provided</scope>
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
index c2235b9abe9..ec4e729f9c7 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
@@ -1,9 +1,9 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer;
-import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
-import com.yahoo.collections.Pair;
+import com.yahoo.config.model.api.ImportedMlFunction;
+import com.yahoo.config.model.api.ImportedMlModel;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.tensor.Tensor;
@@ -23,7 +23,7 @@ import java.util.regex.Pattern;
*
* @author bratseth
*/
-public class ImportedModel {
+public class ImportedModel implements ImportedMlModel {
private static final String defaultSignatureName = "default";
@@ -52,15 +52,17 @@ public class ImportedModel {
}
/** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */
+ @Override
public String name() { return name; }
/** Returns the source path (directory or file) of this model */
+ @Override
public String source() { return source; }
/** Returns an immutable map of the inputs of this */
public Map<String, TensorType> inputs() { return Collections.unmodifiableMap(inputs); }
- // CFG
+ @Override
public Optional<String> inputTypeSpec(String input) {
return Optional.ofNullable(inputs.get(input)).map(TensorType::toString);
}
@@ -69,7 +71,7 @@ public class ImportedModel {
* Returns an immutable map of the small constants of this, represented as strings on the standard tensor form.
* These should have sizes up to a few kb at most, and correspond to constant values given in the source model.
*/
- // CFG
+ @Override
public Map<String, String> smallConstants() { return asTensorStrings(smallConstants); }
boolean hasSmallConstant(String name) { return smallConstants.containsKey(name); }
@@ -79,7 +81,7 @@ public class ImportedModel {
* These can have sizes in gigabytes and must be distributed to nodes separately from configuration.
* For TensorFlow this corresponds to Variable files stored separately.
*/
- // CFG
+ @Override
public Map<String, String> largeConstants() { return asTensorStrings(largeConstants); }
boolean hasLargeConstant(String name) { return largeConstants.containsKey(name); }
@@ -97,7 +99,7 @@ public class ImportedModel {
* Returns an immutable map of the functions that are part of this model.
* Note that the functions themselves are *not* copies and *not* immutable - they must be copied before modification.
*/
- // CFG
+ @Override
public Map<String, String> functions() { return asExpressionStrings(functions); }
/** Returns an immutable map of the signatures of this */
@@ -123,36 +125,36 @@ public class ImportedModel {
* if signatures are used, or the expression name if signatures are not used and there are multiple
* expressions, and the second is the output name if signature names are used.
*/
- // CFG
- public List<ImportedFunction> outputExpressions() {
- List<ImportedFunction> functions = new ArrayList<>();
+ @Override
+ public List<ImportedMlFunction> outputExpressions() {
+ List<ImportedMlFunction> functions = new ArrayList<>();
for (Map.Entry<String, Signature> signatureEntry : signatures().entrySet()) {
for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet())
functions.add(signatureEntry.getValue().outputFunction(outputEntry.getKey(),
signatureEntry.getKey() + "." + outputEntry.getKey()));
if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs
- functions.add(new ImportedFunction(signatureEntry.getKey(),
- new ArrayList<>(signatureEntry.getValue().inputs().values()),
- expressions().get(signatureEntry.getKey()),
- signatureEntry.getValue().inputMap(),
- Optional.empty()));
+ functions.add(new ImportedMlFunction(signatureEntry.getKey(),
+ new ArrayList<>(signatureEntry.getValue().inputs().values()),
+ expressions().get(signatureEntry.getKey()).getRoot().toString(),
+ asTensorTypeStrings(signatureEntry.getValue().inputMap()),
+ Optional.empty()));
}
if (signatures().isEmpty()) { // fallback for models without signatures
if (expressions().size() == 1) {
Map.Entry<String, RankingExpression> singleEntry = this.expressions.entrySet().iterator().next();
- functions.add(new ImportedFunction(singleEntry.getKey(),
- new ArrayList<>(inputs.keySet()),
- singleEntry.getValue(),
- inputs,
- Optional.empty()));
+ functions.add(new ImportedMlFunction(singleEntry.getKey(),
+ new ArrayList<>(inputs.keySet()),
+ singleEntry.getValue().getRoot().toString(),
+ asTensorTypeStrings(inputs),
+ Optional.empty()));
}
else {
for (Map.Entry<String, RankingExpression> expressionEntry : expressions().entrySet()) {
- functions.add(new ImportedFunction(expressionEntry.getKey(),
- new ArrayList<>(inputs.keySet()),
- expressionEntry.getValue(),
- inputs,
- Optional.empty()));
+ functions.add(new ImportedMlFunction(expressionEntry.getKey(),
+ new ArrayList<>(inputs.keySet()),
+ expressionEntry.getValue().getRoot().toString(),
+ asTensorTypeStrings(inputs),
+ Optional.empty()));
}
}
}
@@ -172,6 +174,13 @@ public class ImportedModel {
return values;
}
+ private static Map<String, String> asTensorTypeStrings(Map<String, TensorType> map) {
+ Map<String, String> stringMap = new HashMap<>();
+ for (Map.Entry<String, TensorType> entry : map.entrySet())
+ stringMap.put(entry.getKey(), entry.getValue().toString());
+ return stringMap;
+ }
+
private Map<String, String> asExpressionStrings(Map<String, RankingExpression> map) {
HashMap<String, String> values = new HashMap<>();
for (Map.Entry<String, RankingExpression> entry : map.entrySet())
@@ -246,16 +255,14 @@ public class ImportedModel {
}
/** Returns the expression this output references as an imported function */
- public ImportedFunction outputFunction(String outputName, String functionName) {
- return new ImportedFunction(functionName,
- new ArrayList<>(inputs.values()),
- owner().expressions().get(outputs.get(outputName)),
- inputMap(),
- Optional.empty());
+ public ImportedMlFunction outputFunction(String outputName, String functionName) {
+ return new ImportedMlFunction(functionName,
+ new ArrayList<>(inputs.values()),
+ owner().expressions().get(outputs.get(outputName)).getRoot().toString(),
+ asTensorTypeStrings(inputMap()),
+ Optional.empty());
}
- // CFG
-
@Override
public String toString() { return "signature '" + name + "'"; }
@@ -266,37 +273,4 @@ public class ImportedModel {
}
- // CFG
- public static class ImportedFunction {
-
- private final String name;
- private final List<String> arguments;
- private final Map<String, String> argumentTypes;
- private final String expression;
- private final Optional<String> returnType;
-
- public ImportedFunction(String name, List<String> arguments, RankingExpression expression,
- Map<String, TensorType> argumentTypes, Optional<TensorType> returnType) {
- this.name = name;
- this.arguments = arguments;
- this.expression = expression.getRoot().toString();
- this.argumentTypes = asStrings(argumentTypes);
- this.returnType = returnType.map(TensorType::toString);
- }
-
- private static Map<String, String> asStrings(Map<String, TensorType> map) {
- Map<String, String> stringMap = new HashMap<>();
- for (Map.Entry<String, TensorType> entry : map.entrySet())
- stringMap.put(entry.getKey(), entry.getValue().toString());
- return stringMap;
- }
-
- public String name() { return name; }
- public List<String> arguments() { return Collections.unmodifiableList(arguments); }
- public Map<String, String> argumentTypes() { return Collections.unmodifiableMap(argumentTypes); }
- public String expression() { return expression; }
- public Optional<String> returnType() { return returnType; }
-
- }
-
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
index 8a885938bf9..0200a9032a5 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer;
+import com.yahoo.config.model.api.MlModelImporter;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
@@ -27,20 +28,22 @@ import java.util.logging.Logger;
*
* @author lesters
*/
-public abstract class ModelImporter {
+public abstract class ModelImporter implements MlModelImporter {
private static final Logger log = Logger.getLogger(ModelImporter.class.getName());
/** Returns whether the file or directory at the given path is of the type which can be imported by this */
+ @Override
public abstract boolean canImport(String modelPath);
- /** Imports the given model */
- public abstract ImportedModel importModel(String modelName, String modelPath);
-
- final ImportedModel importModel(String modelName, File modelPath) {
+ @Override
+ public final ImportedModel importModel(String modelName, File modelPath) {
return importModel(modelName, modelPath.toString());
}
+ /** Imports the given model */
+ public abstract ImportedModel importModel(String modelName, String modelPath);
+
/**
* Takes an IntermediateGraph and converts it to a ImportedModel containing
* the actual Vespa ranking expressions.
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/package-info.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/package-info.java
deleted file mode 100644
index 4473f306dcd..00000000000
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/package-info.java
+++ /dev/null
@@ -1,11 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-/**
- * Model integration.
- *
- * CAUTION!: Config models depends on this API. It cannot be changed without ensuring compatibility with
- * old config models.
- */
-@ExportPackage
-package ai.vespa.rankingexpression.importer;
-
-import com.yahoo.osgi.annotation.ExportPackage;