diff options
34 files changed, 278 insertions, 235 deletions
diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/ImportedMlFunction.java b/config-model-api/src/main/java/com/yahoo/config/model/api/ImportedMlFunction.java new file mode 100644 index 00000000000..54cdf807878 --- /dev/null +++ b/config-model-api/src/main/java/com/yahoo/config/model/api/ImportedMlFunction.java @@ -0,0 +1,37 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.config.model.api; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +/** + * An imported function of an imported machine-learned model + * + * @author bratseth + */ +public class ImportedMlFunction { + + private final String name; + private final List<String> arguments; + private final Map<String, String> argumentTypes; + private final String expression; + private final Optional<String> returnType; + + public ImportedMlFunction(String name, List<String> arguments, String expression, + Map<String, String> argumentTypes, Optional<String> returnType) { + this.name = name; + this.arguments = Collections.unmodifiableList(arguments); + this.expression = expression; + this.argumentTypes = Collections.unmodifiableMap(argumentTypes); + this.returnType = returnType; + } + + public String name() { return name; } + public List<String> arguments() { return arguments; } + public Map<String, String> argumentTypes() { return argumentTypes; } + public String expression() { return expression; } + public Optional<String> returnType() { return returnType; } + +} diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/ImportedMlModel.java b/config-model-api/src/main/java/com/yahoo/config/model/api/ImportedMlModel.java new file mode 100644 index 00000000000..078e4c239d6 --- /dev/null +++ b/config-model-api/src/main/java/com/yahoo/config/model/api/ImportedMlModel.java @@ -0,0 +1,23 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.config.model.api; + +import java.util.List; +import java.util.Map; +import java.util.Optional; + +/** + * Config model view of an imported machine-learned model. + * + * @author bratseth + */ +public interface ImportedMlModel { + + String name(); + String source(); + Optional<String> inputTypeSpec(String input); + Map<String, String> smallConstants(); + Map<String, String> largeConstants(); + Map<String, String> functions(); + List<ImportedMlFunction> outputExpressions(); + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java b/config-model-api/src/main/java/com/yahoo/config/model/api/ImportedMlModels.java index bfdaaca1dd7..aeef81788b8 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java +++ b/config-model-api/src/main/java/com/yahoo/config/model/api/ImportedMlModels.java @@ -1,12 +1,12 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.rankingexpression.importer; +package com.yahoo.config.model.api; -import com.google.common.collect.ImmutableMap; import com.yahoo.path.Path; import java.io.File; import java.util.Arrays; import java.util.Collection; +import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.Optional; @@ -18,34 +18,51 @@ import java.util.Optional; * * @author bratseth */ -public class ImportedModels { +public class ImportedMlModels { /** All imported models, indexed by their names */ - private final ImmutableMap<String, ImportedModel> importedModels; + private final Map<String, ImportedMlModel> importedModels; /** Create a null imported models */ - public ImportedModels() { - importedModels = ImmutableMap.of(); + public ImportedMlModels() { + importedModels = Collections.emptyMap(); } - public ImportedModels(File modelsDirectory, Collection<ModelImporter> importers) { - Map<String, ImportedModel> models = new HashMap<>(); + public ImportedMlModels(File modelsDirectory, Collection<MlModelImporter> importers) { + Map<String, ImportedMlModel> models = new HashMap<>(); // Find all subdirectories recursively which contains a model we can read importRecursively(modelsDirectory, models, importers); - importedModels = ImmutableMap.copyOf(models); + importedModels = Collections.unmodifiableMap(models); + } + + /** + * Returns the model at the given location in the application package. + * + * @param modelPath the path to this model (file or directory, depending on model type) + * under the application package, both from the root or relative to the + * models directory works + * @return the model at this path or null if none + */ + public ImportedMlModel get(File modelPath) { + return importedModels.get(toName(modelPath)); + } + + /** Returns an immutable collection of all the imported models */ + public Collection<ImportedMlModel> all() { + return importedModels.values(); } private static void importRecursively(File dir, - Map<String, ImportedModel> models, - Collection<ModelImporter> importers) { + Map<String, ImportedMlModel> models, + Collection<MlModelImporter> importers) { if ( ! dir.isDirectory()) return; Arrays.stream(dir.listFiles()).sorted().forEach(child -> { - Optional<ModelImporter> importer = findImporterOf(child, importers); + Optional<MlModelImporter> importer = findImporterOf(child, importers); if (importer.isPresent()) { String name = toName(child); - ImportedModel existing = models.get(name); + ImportedMlModel existing = models.get(name); if (existing != null) throw new IllegalArgumentException("The models in " + child + " and " + existing.source() + " both resolve to the model name '" + name + "'"); @@ -57,33 +74,10 @@ public class ImportedModels { }); } - private static Optional<ModelImporter> findImporterOf(File path, Collection<ModelImporter> importers) { + private static Optional<MlModelImporter> findImporterOf(File path, Collection<MlModelImporter> importers) { return importers.stream().filter(item -> item.canImport(path.toString())).findFirst(); } - /** - * Returns the model at the given location in the application package. - * - * @param modelPath the path to this model (file or directory, depending on model type) - * under the application package, both from the root or relative to the - * models directory works - * @return the model at this path or null if none - */ - // CFG - public ImportedModel get(File modelPath) { - return importedModels.get(toName(modelPath)); - } - - public ImportedModel get(String modelName) { - return importedModels.get(modelName); - } - - /** Returns an immutable collection of all the imported models */ - // CFG - public Collection<ImportedModel> all() { - return importedModels.values(); - } - private static String toName(File modelFile) { Path modelPath = Path.fromString(modelFile.toString()); if (modelFile.isFile()) diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/MlModelImporter.java b/config-model-api/src/main/java/com/yahoo/config/model/api/MlModelImporter.java new file mode 100644 index 00000000000..d24eeb2d55a --- /dev/null +++ b/config-model-api/src/main/java/com/yahoo/config/model/api/MlModelImporter.java @@ -0,0 +1,16 @@ +package com.yahoo.config.model.api; + +import java.io.File; + +/** + * Config model view of a machine-learned model importer + * + * @author bratseth + */ +public interface MlModelImporter { + + boolean canImport(String modelPath); + + ImportedMlModel importModel(String modelName, File modelPath); + +} diff --git a/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java b/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java index abca92eeb79..06b2452dbd4 100644 --- a/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java +++ b/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java @@ -1,8 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.config.model.deploy; -import ai.vespa.rankingexpression.importer.ImportedModels; -import ai.vespa.rankingexpression.importer.ModelImporter; +import com.yahoo.config.model.api.ImportedMlModels; import com.yahoo.component.Version; import com.yahoo.component.Vtag; import com.yahoo.config.application.api.ApplicationPackage; @@ -11,6 +10,7 @@ import com.yahoo.config.application.api.FileRegistry; import com.yahoo.config.application.api.UnparsedConfigDefinition; import com.yahoo.config.model.api.ConfigDefinitionRepo; import com.yahoo.config.model.api.HostProvisioner; +import com.yahoo.config.model.api.MlModelImporter; import com.yahoo.config.model.api.Model; import com.yahoo.config.model.api.ValidationParameters; import com.yahoo.config.model.application.provider.BaseDeployLogger; @@ -68,7 +68,7 @@ public class DeployState implements ConfigDefinitionStore { private final Zone zone; private final QueryProfiles queryProfiles; private final SemanticRules semanticRules; - private final ImportedModels importedModels; + private final ImportedMlModels importedModels; private final ValidationOverrides validationOverrides; private final Version wantedNodeVespaVersion; private final Instant now; @@ -93,7 +93,7 @@ public class DeployState implements ConfigDefinitionStore { Optional<ConfigDefinitionRepo> configDefinitionRepo, java.util.Optional<Model> previousModel, Set<Rotation> rotations, - Collection<ModelImporter> modelImporters, + Collection<MlModelImporter> modelImporters, Zone zone, QueryProfiles queryProfiles, SemanticRules semanticRules, @@ -114,8 +114,8 @@ public class DeployState implements ConfigDefinitionStore { this.zone = zone; this.queryProfiles = queryProfiles; // TODO: Remove this by seeing how pagetemplates are propagated this.semanticRules = semanticRules; // TODO: Remove this by seeing how pagetemplates are propagated - this.importedModels = new ImportedModels(applicationPackage.getFileReference(ApplicationPackage.MODELS_DIR), - modelImporters); + this.importedModels = new ImportedMlModels(applicationPackage.getFileReference(ApplicationPackage.MODELS_DIR), + modelImporters); this.validationOverrides = applicationPackage.getValidationOverrides().map(ValidationOverrides::fromXml).orElse(ValidationOverrides.empty); this.wantedNodeVespaVersion = wantedNodeVespaVersion; @@ -230,7 +230,7 @@ public class DeployState implements ConfigDefinitionStore { public SemanticRules getSemanticRules() { return semanticRules; } /** The (machine learned) models imported from the models/ directory, as an unmodifiable map indexed by model name */ - public ImportedModels getImportedModels() { return importedModels; } + public ImportedMlModels getImportedModels() { return importedModels; } public Version getWantedNodeVespaVersion() { return wantedNodeVespaVersion; } @@ -247,7 +247,7 @@ public class DeployState implements ConfigDefinitionStore { private Optional<ConfigDefinitionRepo> configDefinitionRepo = Optional.empty(); private Optional<Model> previousModel = Optional.empty(); private Set<Rotation> rotations = new HashSet<>(); - private Collection<ModelImporter> modelImporters = Collections.emptyList(); + private Collection<MlModelImporter> modelImporters = Collections.emptyList(); private Zone zone = Zone.defaultZone(); private Instant now = Instant.now(); private Version wantedNodeVespaVersion = Vtag.currentVersion; @@ -297,7 +297,7 @@ public class DeployState implements ConfigDefinitionStore { return this; } - public Builder modelImporters(Collection<ModelImporter> modelImporters) { + public Builder modelImporters(Collection<MlModelImporter> modelImporters) { this.modelImporters = modelImporters; return this; } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index e434babdeb4..7c0b90c35fa 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -1,7 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition; -import ai.vespa.rankingexpression.importer.ImportedModels; +import com.yahoo.config.model.api.ImportedMlModels; import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.search.query.profile.types.FieldDescription; @@ -647,7 +647,7 @@ public class RankProfile implements Serializable, Cloneable { * Returns a copy of this where the content is optimized for execution. * Compiled profiles should never be modified. */ - public RankProfile compile(QueryProfileRegistry queryProfiles, ImportedModels importedModels) { + public RankProfile compile(QueryProfileRegistry queryProfiles, ImportedMlModels importedModels) { try { RankProfile compiled = this.clone(); compiled.compileThis(queryProfiles, importedModels); @@ -658,7 +658,7 @@ public class RankProfile implements Serializable, Cloneable { } } - private void compileThis(QueryProfileRegistry queryProfiles, ImportedModels importedModels) { + private void compileThis(QueryProfileRegistry queryProfiles, ImportedMlModels importedModels) { checkNameCollisions(getFunctions(), getConstants()); ExpressionTransforms expressionTransforms = new ExpressionTransforms(); @@ -688,7 +688,7 @@ public class RankProfile implements Serializable, Cloneable { private Map<String, RankingExpressionFunction> compileFunctions(Supplier<Map<String, RankingExpressionFunction>> functions, QueryProfileRegistry queryProfiles, - ImportedModels importedModels, + ImportedMlModels importedModels, Map<String, RankingExpressionFunction> inlineFunctions, ExpressionTransforms expressionTransforms) { Map<String, RankingExpressionFunction> compiledFunctions = new LinkedHashMap<>(); @@ -716,7 +716,7 @@ public class RankProfile implements Serializable, Cloneable { private RankingExpression compile(RankingExpression expression, QueryProfileRegistry queryProfiles, - ImportedModels importedModels, + ImportedMlModels importedModels, Map<String, Value> constants, Map<String, RankingExpressionFunction> inlineFunctions, ExpressionTransforms expressionTransforms) { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java index 271c335cd1f..7dc4b815da6 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java @@ -1,7 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.derived; -import ai.vespa.rankingexpression.importer.ImportedModels; +import com.yahoo.config.model.api.ImportedMlModels; import com.yahoo.config.ConfigInstance; import com.yahoo.config.model.application.provider.BaseDeployLogger; import com.yahoo.config.application.api.DeployLogger; @@ -49,7 +49,7 @@ public class DerivedConfiguration { public DerivedConfiguration(Search search, RankProfileRegistry rankProfileRegistry, QueryProfileRegistry queryProfiles, - ImportedModels importedModels) { + ImportedMlModels importedModels) { this(search, new BaseDeployLogger(), rankProfileRegistry, queryProfiles, importedModels); } @@ -68,7 +68,7 @@ public class DerivedConfiguration { DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry, QueryProfileRegistry queryProfiles, - ImportedModels importedModels) { + ImportedMlModels importedModels) { Validator.ensureNotNull("Search definition", search); this.search = search; if ( ! search.isDocumentsOnly()) { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java index 51e0c1d2f47..4c117e44857 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java @@ -1,7 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.derived; -import ai.vespa.rankingexpression.importer.ImportedModels; +import com.yahoo.config.model.api.ImportedMlModels; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.RankingConstant; @@ -45,7 +45,7 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ AttributeFields attributeFields, RankProfileRegistry rankProfileRegistry, QueryProfileRegistry queryProfiles, - ImportedModels importedModels) { + ImportedMlModels importedModels) { setName(search == null ? "default" : search.getName()); this.rankingConstants = rankingConstants; deriveRankProfiles(rankProfileRegistry, queryProfiles, importedModels, search, attributeFields); @@ -53,7 +53,7 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ private void deriveRankProfiles(RankProfileRegistry rankProfileRegistry, QueryProfileRegistry queryProfiles, - ImportedModels importedModels, + ImportedMlModels importedModels, Search search, AttributeFields attributeFields) { if (search != null) { // profiles belonging to a search have a default profile diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java index 97b6def07ab..b7f515cedd4 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java @@ -1,7 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.derived; -import ai.vespa.rankingexpression.importer.ImportedModels; +import com.yahoo.config.model.api.ImportedMlModels; import com.google.common.collect.ImmutableList; import com.yahoo.collections.Pair; import com.yahoo.compress.Compressor; @@ -50,7 +50,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { /** * Creates a raw rank profile from the given rank profile */ - public RawRankProfile(RankProfile rankProfile, QueryProfileRegistry queryProfiles, ImportedModels importedModels, AttributeFields attributeFields) { + public RawRankProfile(RankProfile rankProfile, QueryProfileRegistry queryProfiles, ImportedMlModels importedModels, AttributeFields attributeFields) { this.name = rankProfile.getName(); compressedProperties = compress(new Deriver(rankProfile, queryProfiles, importedModels, attributeFields).derive()); } @@ -148,7 +148,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { /** * Creates a raw rank profile from the given rank profile */ - public Deriver(RankProfile rankProfile, QueryProfileRegistry queryProfiles, ImportedModels importedModels, AttributeFields attributeFields) { + public Deriver(RankProfile rankProfile, QueryProfileRegistry queryProfiles, ImportedMlModels importedModels, AttributeFields attributeFields) { RankProfile compiled = rankProfile.compile(queryProfiles, importedModels); attributeTypes = compiled.getAttributeTypes(); queryFeatureTypes = compiled.getQueryFeatureTypes(); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java index 9043bf966a3..f20298cfe1a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java @@ -1,7 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.expressiontransforms; -import ai.vespa.rankingexpression.importer.ImportedModels; +import com.yahoo.config.model.api.ImportedMlModels; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchlib.rankingexpression.evaluation.Value; @@ -19,13 +19,13 @@ public class RankProfileTransformContext extends TransformContext { private final RankProfile rankProfile; private final QueryProfileRegistry queryProfiles; - private final ImportedModels importedModels; + private final ImportedMlModels importedModels; private final Map<String, RankProfile.RankingExpressionFunction> inlineFunctions; private final Map<String, String> rankProperties = new HashMap<>(); public RankProfileTransformContext(RankProfile rankProfile, QueryProfileRegistry queryProfiles, - ImportedModels importedModels, + ImportedMlModels importedModels, Map<String, Value> constants, Map<String, RankProfile.RankingExpressionFunction> inlineFunctions) { super(constants); @@ -37,7 +37,7 @@ public class RankProfileTransformContext extends TransformContext { public RankProfile rankProfile() { return rankProfile; } public QueryProfileRegistry queryProfiles() { return queryProfiles; } - public ImportedModels importedModels() { return importedModels; } + public ImportedMlModels importedModels() { return importedModels; } public Map<String, RankProfile.RankingExpressionFunction> inlineFunctions() { return inlineFunctions; } public Map<String, String> rankProperties() { return rankProperties; } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java index a01e7cffd84..58fc08d15e7 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java @@ -1,8 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model; -import ai.vespa.rankingexpression.importer.ImportedModel; -import ai.vespa.rankingexpression.importer.ImportedModels; +import com.yahoo.config.model.api.ImportedMlModel; import com.yahoo.config.ConfigBuilder; import com.yahoo.config.ConfigInstance; import com.yahoo.config.ConfigInstance.Builder; @@ -20,6 +19,7 @@ import com.yahoo.config.model.ConfigModelRepo; import com.yahoo.config.model.NullConfigModelRegistry; import com.yahoo.config.model.api.FileDistribution; import com.yahoo.config.model.api.HostInfo; +import com.yahoo.config.model.api.ImportedMlModels; import com.yahoo.config.model.api.Model; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.config.model.producer.AbstractConfigProducer; @@ -217,11 +217,11 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri * Creates a rank profile not attached to any search definition, for each imported model in the application package, * and adds it to the given rank profile registry. */ - private void createGlobalRankProfiles(DeployLogger deployLogger, ImportedModels importedModels, + private void createGlobalRankProfiles(DeployLogger deployLogger, ImportedMlModels importedModels, RankProfileRegistry rankProfileRegistry, QueryProfiles queryProfiles) { if ( ! importedModels.all().isEmpty()) { // models/ directory is available - for (ImportedModel model : importedModels.all()) { + for (ImportedMlModel model : importedModels.all()) { RankProfile profile = new RankProfile(model.name(), this, rankProfileRegistry); rankProfileRegistry.add(profile); ConvertedModel convertedModel = ConvertedModel.fromSource(new ModelName(model.name()), diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java index 8fa9dfa9e91..954f20f36c0 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java @@ -1,7 +1,6 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model; -import ai.vespa.rankingexpression.importer.ModelImporter; import com.google.inject.Inject; import com.yahoo.component.provider.ComponentRegistry; import com.yahoo.config.application.api.ApplicationPackage; @@ -11,6 +10,7 @@ import com.yahoo.config.model.NullConfigModelRegistry; import com.yahoo.config.model.api.ConfigChangeAction; import com.yahoo.config.model.api.ConfigModelPlugin; import com.yahoo.config.model.api.HostProvisioner; +import com.yahoo.config.model.api.MlModelImporter; import com.yahoo.config.model.api.Model; import com.yahoo.config.model.api.ModelContext; import com.yahoo.config.model.api.ModelCreateResult; @@ -44,7 +44,7 @@ public class VespaModelFactory implements ModelFactory { private static final Logger log = Logger.getLogger(VespaModelFactory.class.getName()); private final ConfigModelRegistry configModelRegistry; - private final Collection<ModelImporter> modelImporters; + private final Collection<MlModelImporter> modelImporters; private final Zone zone; private final Clock clock; private final Version version; @@ -52,7 +52,7 @@ public class VespaModelFactory implements ModelFactory { /** Creates a factory for vespa models for this version of the source */ @Inject public VespaModelFactory(ComponentRegistry<ConfigModelPlugin> pluginRegistry, - ComponentRegistry<ModelImporter> modelImporters, + ComponentRegistry<MlModelImporter> modelImporters, Zone zone) { this.version = Version.fromIntValues(VespaVersion.major, VespaVersion.minor, VespaVersion.micro); List<ConfigModelBuilder> modelBuilders = new ArrayList<>(); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index 259ac5227ae..c834bea7be2 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -1,11 +1,12 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.ml; -import ai.vespa.rankingexpression.importer.ImportedModel; +import com.yahoo.config.model.api.ImportedMlFunction; import com.google.common.collect.ImmutableMap; import com.yahoo.collections.Pair; import com.yahoo.config.application.api.ApplicationFile; import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.model.api.ImportedMlModel; import com.yahoo.config.model.application.provider.FilesApplicationPackage; import com.yahoo.io.IOUtils; import com.yahoo.path.Path; @@ -70,12 +71,12 @@ public class ConvertedModel { private final ImmutableMap<String, ExpressionFunction> expressions; /** The source importedModel, or empty if this was created from a stored converted model */ - private final Optional<ImportedModel> sourceModel; + private final Optional<ImportedMlModel> sourceModel; private ConvertedModel(ModelName modelName, String modelDescription, Map<String, ExpressionFunction> expressions, - Optional<ImportedModel> sourceModel) { + Optional<ImportedMlModel> sourceModel) { this.modelName = modelName; this.modelDescription = modelDescription; this.expressions = ImmutableMap.copyOf(expressions); @@ -90,13 +91,13 @@ public class ConvertedModel { * @param pathIsFile true if that path (this kind of model) is stored in a file, false if it is in a directory */ public static ConvertedModel fromSourceOrStore(Path modelPath, boolean pathIsFile, RankProfileTransformContext context) { - ImportedModel sourceModel = // TODO: Convert to name here, make sure its done just one way + ImportedMlModel sourceModel = // TODO: Convert to name here, make sure its done just one way context.importedModels().get(sourceModelFile(context.rankProfile().applicationPackage(), modelPath)); ModelName modelName = new ModelName(context.rankProfile().getName(), modelPath, pathIsFile); if (sourceModel == null && ! new ModelStore(context.rankProfile().applicationPackage(), modelName).exists()) throw new IllegalArgumentException("No model '" + modelPath + "' is available. Available models: " + - context.importedModels().all().stream().map(ImportedModel::source).collect(Collectors.joining(", "))); + context.importedModels().all().stream().map(ImportedMlModel::source).collect(Collectors.joining(", "))); if (sourceModel != null) { return fromSource(modelName, @@ -116,7 +117,7 @@ public class ConvertedModel { String modelDescription, RankProfile rankProfile, QueryProfileRegistry queryProfileRegistry, - ImportedModel importedModel) { + ImportedMlModel importedModel) { ModelStore modelStore = new ModelStore(rankProfile.applicationPackage(), modelName); return new ConvertedModel(modelName, modelDescription, @@ -187,7 +188,7 @@ public class ConvertedModel { // ----------------------- Static model conversion/storage below here - private static Map<String, ExpressionFunction> convertAndStore(ImportedModel model, + private static Map<String, ExpressionFunction> convertAndStore(ImportedMlModel model, RankProfile profile, QueryProfileRegistry queryProfiles, ModelStore store) { @@ -202,7 +203,7 @@ public class ConvertedModel { // Add expressions Map<String, ExpressionFunction> expressions = new HashMap<>(); - for (ImportedModel.ImportedFunction outputFunction : model.outputExpressions()) { + for (ImportedMlFunction outputFunction : model.outputExpressions()) { ExpressionFunction expression = asExpressionFunction(outputFunction); addExpression(expression, expression.getName(), constantsReplacedByFunctions, @@ -219,7 +220,7 @@ public class ConvertedModel { return expressions; } - private static ExpressionFunction asExpressionFunction(ImportedModel.ImportedFunction function) { + private static ExpressionFunction asExpressionFunction(ImportedMlFunction function) { try { Map<String, TensorType> argumentTypes = new HashMap<>(); for (Map.Entry<String, String> entry : function.argumentTypes().entrySet()) @@ -239,7 +240,7 @@ public class ConvertedModel { private static void addExpression(ExpressionFunction expression, String expressionName, Set<String> constantsReplacedByFunctions, - ImportedModel model, + ImportedMlModel model, ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles, @@ -322,7 +323,7 @@ public class ConvertedModel { * Verify that the inputs declared in the given expression exists in the given rank profile as functions, * and return tensors of the correct types. */ - private static void verifyInputs(RankingExpression expression, ImportedModel model, + private static void verifyInputs(RankingExpression expression, ImportedMlModel model, RankProfile profile, QueryProfileRegistry queryProfiles) { Set<String> functionNames = new HashSet<>(); addFunctionNamesIn(expression.getRoot(), functionNames, model); @@ -359,7 +360,7 @@ public class ConvertedModel { } /** Add the generated functions to the rank profile */ - private static void addGeneratedFunctions(ImportedModel model, RankProfile profile) { + private static void addGeneratedFunctions(ImportedMlModel model, RankProfile profile) { model.functions().forEach((k, v) -> addGeneratedFunctionToProfile(profile, k, RankingExpression.from(v))); } @@ -368,7 +369,7 @@ public class ConvertedModel { * function specifies that a single exemplar should be evaluated, we can * reduce the batch dimension out. */ - private static void reduceBatchDimensions(RankingExpression expression, ImportedModel model, + private static void reduceBatchDimensions(RankingExpression expression, ImportedMlModel model, RankProfile profile, QueryProfileRegistry queryProfiles) { TypeContext<Reference> typeContext = profile.typeContext(queryProfiles); TensorType typeBeforeReducing = expression.getRoot().type(typeContext); @@ -396,7 +397,7 @@ public class ConvertedModel { expression.setRoot(root); } - private static ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedModel model, + private static ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedMlModel model, TypeContext<Reference> typeContext) { if (node instanceof TensorFunctionNode) { TensorFunction tensorFunction = ((TensorFunctionNode) node).function(); @@ -502,7 +503,7 @@ public class ConvertedModel { return node; } - private static void addFunctionNamesIn(ExpressionNode node, Set<String> names, ImportedModel model) { + private static void addFunctionNamesIn(ExpressionNode node, Set<String> names, ImportedMlModel model) { if (node instanceof ReferenceNode) { ReferenceNode referenceNode = (ReferenceNode)node; if (referenceNode.getOutput() == null) { // function references cannot specify outputs diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java index 9af5d82181e..131972ffc73 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java @@ -4,7 +4,7 @@ package com.yahoo.searchdefinition; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.derived.DerivedConfiguration; import com.yahoo.searchdefinition.parser.ParseException; -import ai.vespa.rankingexpression.importer.ImportedModels; +import com.yahoo.config.model.api.ImportedMlModels; import com.yahoo.yolean.Exceptions; import org.junit.Test; @@ -26,7 +26,7 @@ public class IncorrectRankingExpressionFileRefTestCase extends SearchDefinitionT Search search = SearchBuilder.buildFromFile("src/test/examples/incorrectrankingexpressionfileref.sd", registry, new QueryProfileRegistry()); - new DerivedConfiguration(search, registry, new QueryProfileRegistry(), new ImportedModels()); // cause rank profile parsing + new DerivedConfiguration(search, registry, new QueryProfileRegistry(), new ImportedMlModels()); // cause rank profile parsing fail("parsing should have failed"); } catch (IllegalArgumentException e) { String message = Exceptions.toMessageString(e); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java index c2b7eb90487..06761ad45bc 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java @@ -16,7 +16,7 @@ import com.yahoo.searchdefinition.document.RankType; import com.yahoo.searchdefinition.document.SDDocumentType; import com.yahoo.searchdefinition.document.SDField; import com.yahoo.searchdefinition.parser.ParseException; -import ai.vespa.rankingexpression.importer.ImportedModels; +import com.yahoo.config.model.api.ImportedMlModels; import org.junit.Test; import java.util.Iterator; @@ -91,7 +91,7 @@ public class RankProfileTestCase extends SearchDefinitionTestCase { assertEquals(8, rankProfile.getNumThreadsPerSearch()); assertEquals(70, rankProfile.getMinHitsPerThread()); assertEquals(1200, rankProfile.getNumSearchPartitions()); - RawRankProfile rawRankProfile = new RawRankProfile(rankProfile, new QueryProfileRegistry(), new ImportedModels(), attributeFields); + RawRankProfile rawRankProfile = new RawRankProfile(rankProfile, new QueryProfileRegistry(), new ImportedMlModels(), attributeFields); assertTrue(findProperty(rawRankProfile.configProperties(), "vespa.matching.termwise_limit").isPresent()); assertEquals("0.78", findProperty(rawRankProfile.configProperties(), "vespa.matching.termwise_limit").get()); assertTrue(findProperty(rawRankProfile.configProperties(), "vespa.matching.numthreadspersearch").isPresent()); @@ -126,7 +126,7 @@ public class RankProfileTestCase extends SearchDefinitionTestCase { } private static void assertAttributeTypeSettings(RankProfile profile, Search search) { - RawRankProfile rawProfile = new RawRankProfile(profile, new QueryProfileRegistry(), new ImportedModels(), new AttributeFields(search)); + RawRankProfile rawProfile = new RawRankProfile(profile, new QueryProfileRegistry(), new ImportedMlModels(), new AttributeFields(search)); assertEquals("tensor(x[10])", findProperty(rawProfile.configProperties(), "vespa.type.attribute.a").get()); assertEquals("tensor(y{})", findProperty(rawProfile.configProperties(), "vespa.type.attribute.b").get()); assertEquals("tensor(x[])", findProperty(rawProfile.configProperties(), "vespa.type.attribute.c").get()); @@ -168,7 +168,7 @@ public class RankProfileTestCase extends SearchDefinitionTestCase { } private static void assertQueryFeatureTypeSettings(RankProfile profile, Search search) { - RawRankProfile rawProfile = new RawRankProfile(profile, new QueryProfileRegistry(), new ImportedModels(), new AttributeFields(search)); + RawRankProfile rawProfile = new RawRankProfile(profile, new QueryProfileRegistry(), new ImportedMlModels(), new AttributeFields(search)); assertEquals("tensor(x[10])", findProperty(rawProfile.configProperties(), "vespa.type.query.tensor1").get()); assertEquals("tensor(y{})", findProperty(rawProfile.configProperties(), "vespa.type.query.tensor2").get()); assertFalse(findProperty(rawProfile.configProperties(), "vespa.type.query.tensor3").isPresent()); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java index 9ee32c7e6b9..1d6a75f039d 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java @@ -5,7 +5,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; -import ai.vespa.rankingexpression.importer.ImportedModels; +import com.yahoo.config.model.api.ImportedMlModels; import org.junit.Test; import static org.junit.Assert.assertEquals; @@ -55,7 +55,7 @@ public class RankPropertiesTestCase extends SearchDefinitionTestCase { assertEquals("query(a) = 1500", parent.getRankProperties().get(0).toString()); // Check derived model - RawRankProfile rawParent = new RawRankProfile(parent, new QueryProfileRegistry(), new ImportedModels(), attributeFields); + RawRankProfile rawParent = new RawRankProfile(parent, new QueryProfileRegistry(), new ImportedMlModels(), attributeFields); assertEquals("(query(a),1500)", rawParent.configProperties().get(0).toString()); } @@ -67,7 +67,7 @@ public class RankPropertiesTestCase extends SearchDefinitionTestCase { // Check derived model RawRankProfile rawChild = new RawRankProfile(rankProfileRegistry.get(search, "child"), new QueryProfileRegistry(), - new ImportedModels(), + new ImportedMlModels(), attributeFields); assertEquals("(query(a),2000)", rawChild.configProperties().get(0).toString()); } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java index 6b6498528e4..af6507f352d 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java @@ -3,7 +3,7 @@ package com.yahoo.searchdefinition; import com.yahoo.collections.Pair; import com.yahoo.search.query.profile.QueryProfileRegistry; -import ai.vespa.rankingexpression.importer.ImportedModels; +import com.yahoo.config.model.api.ImportedMlModels; import com.yahoo.yolean.Exceptions; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; @@ -67,19 +67,19 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase "}\n"); builder.build(); Search s = builder.getSearch(); - RankProfile parent = rankProfileRegistry.get(s, "parent").compile(queryProfileRegistry, new ImportedModels()); + RankProfile parent = rankProfileRegistry.get(s, "parent").compile(queryProfileRegistry, new ImportedMlModels()); assertEquals("0.0", parent.getFirstPhaseRanking().getRoot().toString()); - RankProfile child1 = rankProfileRegistry.get(s, "child1").compile(queryProfileRegistry, new ImportedModels()); + RankProfile child1 = rankProfileRegistry.get(s, "child1").compile(queryProfileRegistry, new ImportedMlModels()); assertEquals("6.5", child1.getFirstPhaseRanking().getRoot().toString()); assertEquals("11.5", child1.getSecondPhaseRanking().getRoot().toString()); - RankProfile child2 = rankProfileRegistry.get(s, "child2").compile(queryProfileRegistry, new ImportedModels()); + RankProfile child2 = rankProfileRegistry.get(s, "child2").compile(queryProfileRegistry, new ImportedMlModels()); assertEquals("16.6", child2.getFirstPhaseRanking().getRoot().toString()); assertEquals("foo: 14.0", child2.getFunctions().get("foo").function().getBody().toString()); List<Pair<String, String>> rankProperties = new RawRankProfile(child2, queryProfileRegistry, - new ImportedModels(), + new ImportedMlModels(), new AttributeFields(s)).configProperties(); assertEquals("(rankingExpression(foo).rankingScript,14.0)", rankProperties.get(0).toString()); assertEquals("(rankingExpression(firstphase).rankingScript,16.6)", rankProperties.get(2).toString()); @@ -110,7 +110,7 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase builder.build(); Search s = builder.getSearch(); try { - rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedModels()); + rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedMlModels()); fail("Should have caused an exception"); } catch (IllegalArgumentException e) { @@ -171,7 +171,7 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase RankProfile profile = rankProfileRegistry.get(s, "test"); assertEquals("safeLog(popShareSlowDecaySignal,myValue)", profile.getFunctions().get("POP_SLOW_SCORE").function().getBody().getRoot().toString()); assertEquals("safeLog(popShareSlowDecaySignal,-9.21034037)", - profile.compile(new QueryProfileRegistry(), new ImportedModels()).getFunctions().get("POP_SLOW_SCORE").function().getBody().getRoot().toString()); + profile.compile(new QueryProfileRegistry(), new ImportedMlModels()).getFunctions().get("POP_SLOW_SCORE").function().getBody().getRoot().toString()); } @Test @@ -194,7 +194,7 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase Search s = builder.getSearch(); RankProfile profile = rankProfileRegistry.get(s, "test"); assertEquals("k1 + (k2 + k3) / 100000000.0", - profile.compile(new QueryProfileRegistry(), new ImportedModels()).getFunctions().get("rank_default").function().getBody().getRoot().toString()); + profile.compile(new QueryProfileRegistry(), new ImportedMlModels()).getFunctions().get("rank_default").function().getBody().getRoot().toString()); } @Test @@ -220,7 +220,7 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase Search s = builder.getSearch(); RankProfile profile = rankProfileRegistry.get(s, "test"); assertEquals("0.5 + 50 * (attribute(rating_yelp) - 3)", - profile.compile(new QueryProfileRegistry(), new ImportedModels()).getFunctions().get("rank_default").function().getBody().getRoot().toString()); + profile.compile(new QueryProfileRegistry(), new ImportedMlModels()).getFunctions().get("rank_default").function().getBody().getRoot().toString()); } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java index facb258ce28..368f6fec80e 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java @@ -6,7 +6,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; -import ai.vespa.rankingexpression.importer.ImportedModels; +import com.yahoo.config.model.api.ImportedMlModels; import org.junit.Test; import java.util.Optional; @@ -63,10 +63,10 @@ public class RankingExpressionInliningTestCase extends SearchDefinitionTestCase builder.build(); Search s = builder.getSearch(); - RankProfile parent = rankProfileRegistry.get(s, "parent").compile(new QueryProfileRegistry(), new ImportedModels()); + RankProfile parent = rankProfileRegistry.get(s, "parent").compile(new QueryProfileRegistry(), new ImportedMlModels()); assertEquals("7.0 * (3 + attribute(a) + attribute(b) * (attribute(a) * 3 + if (7.0 < attribute(a), 1, 2) == 0))", parent.getFirstPhaseRanking().getRoot().toString()); - RankProfile child = rankProfileRegistry.get(s, "child").compile(new QueryProfileRegistry(), new ImportedModels()); + RankProfile child = rankProfileRegistry.get(s, "child").compile(new QueryProfileRegistry(), new ImportedMlModels()); assertEquals("7.0 * (9 + attribute(a))", child.getFirstPhaseRanking().getRoot().toString()); } @@ -123,14 +123,14 @@ public class RankingExpressionInliningTestCase extends SearchDefinitionTestCase builder.build(); Search s = builder.getSearch(); - RankProfile parent = rankProfileRegistry.get(s, "parent").compile(new QueryProfileRegistry(), new ImportedModels()); + RankProfile parent = rankProfileRegistry.get(s, "parent").compile(new QueryProfileRegistry(), new ImportedMlModels()); assertEquals("17.0", parent.getFirstPhaseRanking().getRoot().toString()); assertEquals("0.0", parent.getSecondPhaseRanking().getRoot().toString()); assertEquals("10.0", getRankingExpression("foo", parent, s)); assertEquals("17.0", getRankingExpression("firstphase", parent, s)); assertEquals("0.0", getRankingExpression("secondphase", parent, s)); - RankProfile child = rankProfileRegistry.get(s, "child").compile(new QueryProfileRegistry(), new ImportedModels()); + RankProfile child = rankProfileRegistry.get(s, "child").compile(new QueryProfileRegistry(), new ImportedMlModels()); assertEquals("31.0 + bar + arg(4.0)", child.getFirstPhaseRanking().getRoot().toString()); assertEquals("24.0", child.getSecondPhaseRanking().getRoot().toString()); assertEquals("12.0", getRankingExpression("foo", child, s)); @@ -179,7 +179,7 @@ public class RankingExpressionInliningTestCase extends SearchDefinitionTestCase builder.build(); Search s = builder.getSearch(); - RankProfile test = rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedModels()); + RankProfile test = rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedMlModels()); assertEquals("attribute(a) + C + (attribute(b) + 1)", test.getFirstPhaseRanking().getRoot().toString()); assertEquals("attribute(a) + attribute(b)", getRankingExpression("C", test, s)); assertEquals("attribute(b) + 1", getRankingExpression("D", test, s)); @@ -210,7 +210,7 @@ public class RankingExpressionInliningTestCase extends SearchDefinitionTestCase private String getRankingExpression(String name, RankProfile rankProfile, Search search) { Optional<String> rankExpression = - new RawRankProfile(rankProfile, new QueryProfileRegistry(), new ImportedModels(), new AttributeFields(search)) + new RawRankProfile(rankProfile, new QueryProfileRegistry(), new ImportedMlModels(), new AttributeFields(search)) .configProperties() .stream() .filter(r -> r.getFirst().equals("rankingExpression(" + name + ").rankingScript")) diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java index 849a58b4c0e..a0deedb404a 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java @@ -9,7 +9,7 @@ import com.yahoo.search.query.profile.types.QueryProfileType; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; -import ai.vespa.rankingexpression.importer.ImportedModels; +import com.yahoo.config.model.api.ImportedMlModels; import org.junit.Test; import java.util.List; @@ -45,10 +45,10 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase "}\n"); builder.build(); Search s = builder.getSearch(); - RankProfile test = rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedModels()); + RankProfile test = rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedMlModels()); List<Pair<String, String>> testRankProperties = new RawRankProfile(test, new QueryProfileRegistry(), - new ImportedModels(), + new ImportedMlModels(), new AttributeFields(s)).configProperties(); assertEquals("(rankingExpression(sin).rankingScript,x * x)", testRankProperties.get(0).toString()); @@ -89,10 +89,10 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase "}\n"); builder.build(); Search s = builder.getSearch(); - RankProfile test = rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedModels()); + RankProfile test = rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedMlModels()); List<Pair<String, String>> testRankProperties = new RawRankProfile(test, new QueryProfileRegistry(), - new ImportedModels(), + new ImportedMlModels(), new AttributeFields(s)).configProperties(); assertEquals("(rankingExpression(tan).rankingScript,x * x)", testRankProperties.get(0).toString()); @@ -139,10 +139,10 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase "}\n"); builder.build(); Search s = builder.getSearch(); - RankProfile test = rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedModels()); + RankProfile test = rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedMlModels()); List<Pair<String, String>> testRankProperties = new RawRankProfile(test, new QueryProfileRegistry(), - new ImportedModels(), + new ImportedMlModels(), new AttributeFields(s)).configProperties(); assertEquals("(rankingExpression(sin).rankingScript,x * x)", testRankProperties.get(0).toString()); @@ -203,10 +203,10 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase "}\n"); builder.build(); Search s = builder.getSearch(); - RankProfile test = rankProfileRegistry.get(s, "test").compile(queryProfiles, new ImportedModels()); + RankProfile test = rankProfileRegistry.get(s, "test").compile(queryProfiles, new ImportedMlModels()); List<Pair<String, String>> testRankProperties = new RawRankProfile(test, queryProfiles, - new ImportedModels(), + new ImportedMlModels(), new AttributeFields(s)).configProperties(); assertEquals("(rankingExpression(relu).rankingScript,max(1.0,x))", testRankProperties.get(0).toString()); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java index e4cf5aba5bd..830b7d531c3 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java @@ -4,7 +4,7 @@ package com.yahoo.searchdefinition; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.derived.DerivedConfiguration; import com.yahoo.searchdefinition.parser.ParseException; -import ai.vespa.rankingexpression.importer.ImportedModels; +import com.yahoo.config.model.api.ImportedMlModels; import com.yahoo.yolean.Exceptions; import org.junit.Test; @@ -26,7 +26,7 @@ public class RankingExpressionValidationTestCase extends SearchDefinitionTestCas try { RankProfileRegistry registry = new RankProfileRegistry(); Search search = importWithExpression(expression, registry); - new DerivedConfiguration(search, registry, new QueryProfileRegistry(), new ImportedModels()); // cause rank profile parsing + new DerivedConfiguration(search, registry, new QueryProfileRegistry(), new ImportedMlModels()); // cause rank profile parsing fail("No exception on incorrect ranking expression " + expression); } catch (IllegalArgumentException e) { // Success diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java index 2e942d4b8d7..5e8a4597a2d 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java @@ -7,7 +7,7 @@ import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.SearchDefinitionTestCase; import com.yahoo.searchdefinition.parser.ParseException; -import ai.vespa.rankingexpression.importer.ImportedModels; +import com.yahoo.config.model.api.ImportedMlModels; import com.yahoo.vespa.configmodel.producers.DocumentManager; import com.yahoo.vespa.configmodel.producers.DocumentTypes; @@ -37,7 +37,7 @@ public abstract class AbstractExportingTestCase extends SearchDefinitionTestCase DerivedConfiguration config = new DerivedConfiguration(builder.getSearch(searchDefinitionName), builder.getRankProfileRegistry(), builder.getQueryProfileRegistry(), - new ImportedModels()); + new ImportedMlModels()); return export(dirName, builder, config); } @@ -45,7 +45,7 @@ public abstract class AbstractExportingTestCase extends SearchDefinitionTestCase DerivedConfiguration config = new DerivedConfiguration(search, builder.getRankProfileRegistry(), builder.getQueryProfileRegistry(), - new ImportedModels()); + new ImportedMlModels()); return export(dirName, builder, config); } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java index 1ed9dcd9a74..2160dda45aa 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java @@ -9,7 +9,7 @@ import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.SearchDefinitionTestCase; import com.yahoo.searchdefinition.document.SDDocumentType; import com.yahoo.searchdefinition.document.SDField; -import ai.vespa.rankingexpression.importer.ImportedModels; +import com.yahoo.config.model.api.ImportedMlModels; import org.junit.Test; /** @@ -32,7 +32,7 @@ public class EmptyRankProfileTestCase extends SearchDefinitionTestCase { doc.addField(new SDField("c", DataType.STRING)); search = SearchBuilder.buildFromRawSearch(search, rankProfileRegistry, new QueryProfileRegistry()); - new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(), new ImportedModels()); + new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(), new ImportedMlModels()); } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java index 674ae5274b0..a7821615f48 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java @@ -11,7 +11,7 @@ import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.document.SDDocumentType; import com.yahoo.searchdefinition.document.SDField; import com.yahoo.searchdefinition.processing.Processing; -import ai.vespa.rankingexpression.importer.ImportedModels; +import com.yahoo.config.model.api.ImportedMlModels; import com.yahoo.vespa.model.container.search.QueryProfiles; import org.junit.Test; @@ -42,7 +42,7 @@ public class LiteralBoostTestCase extends AbstractExportingTestCase { other.addRankSetting(new RankProfile.RankSetting("a", RankProfile.RankSetting.Type.LITERALBOOST, 333)); new Processing().process(search, new BaseDeployLogger(), rankProfileRegistry, new QueryProfiles(), true, false); - DerivedConfiguration derived=new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(), new ImportedModels()); + DerivedConfiguration derived=new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(), new ImportedMlModels()); // Check attribute fields derived.getAttributeFields(); // TODO: assert content @@ -73,7 +73,7 @@ public class LiteralBoostTestCase extends AbstractExportingTestCase { other.addRankSetting(new RankProfile.RankSetting("a", RankProfile.RankSetting.Type.LITERALBOOST, 333)); search = SearchBuilder.buildFromRawSearch(search, rankProfileRegistry, new QueryProfileRegistry()); - DerivedConfiguration derived = new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(),new ImportedModels()); + DerivedConfiguration derived = new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(),new ImportedMlModels()); // Check il script addition assertIndexing(Arrays.asList("clear_state | guard { input a | tokenize normalize stem:\"SHORTEST\" | index a; }", @@ -100,7 +100,7 @@ public class LiteralBoostTestCase extends AbstractExportingTestCase { field2.setLiteralBoost(20); search = SearchBuilder.buildFromRawSearch(search, rankProfileRegistry, new QueryProfileRegistry()); - new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(), new ImportedModels()); + new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(), new ImportedMlModels()); assertIndexing(Arrays.asList("clear_state | guard { input title | tokenize normalize stem:\"SHORTEST\" | summary title | index title; }", "clear_state | guard { input body | tokenize normalize stem:\"SHORTEST\" | summary body | index body; }", "clear_state | guard { input title | tokenize | index title_literal; }", diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java index cfbcbc74f2d..61d1cd36f56 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java @@ -5,7 +5,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.parser.ParseException; -import ai.vespa.rankingexpression.importer.ImportedModels; +import com.yahoo.config.model.api.ImportedMlModels; import org.junit.Test; import java.io.File; @@ -36,7 +36,7 @@ public class SimpleInheritTestCase extends AbstractExportingTestCase { DerivedConfiguration config = new DerivedConfiguration(search, builder.getRankProfileRegistry(), new QueryProfileRegistry(), - new ImportedModels()); + new ImportedMlModels()); config.export(toDirName); checkDir(toDirName, expectedResultsDirName); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java index 24a6b0d6aaa..a34d4de4f51 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java @@ -10,7 +10,7 @@ import com.yahoo.searchdefinition.SearchDefinitionTestCase; import com.yahoo.searchdefinition.document.SDDocumentType; import com.yahoo.searchdefinition.document.SDField; import com.yahoo.searchdefinition.processing.Processing; -import ai.vespa.rankingexpression.importer.ImportedModels; +import com.yahoo.config.model.api.ImportedMlModels; import com.yahoo.vespa.model.container.search.QueryProfiles; import org.junit.Test; @@ -34,7 +34,7 @@ public class TypeConversionTestCase extends SearchDefinitionTestCase { document.addField(a); new Processing().process(search, new BaseDeployLogger(), rankProfileRegistry, new QueryProfiles(), true, false); - DerivedConfiguration derived = new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(), new ImportedModels()); + DerivedConfiguration derived = new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(), new ImportedMlModels()); IndexInfo indexInfo = derived.getIndexInfo(); assertFalse(indexInfo.hasCommand("default", "compact-to-term")); } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java index 599c59c30a4..ae70061696b 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java @@ -8,7 +8,7 @@ import com.yahoo.searchdefinition.SearchDefinitionTestCase; import com.yahoo.searchdefinition.derived.DerivedConfiguration; import com.yahoo.searchdefinition.document.SDDocumentType; import com.yahoo.searchdefinition.parser.ParseException; -import ai.vespa.rankingexpression.importer.ImportedModels; +import com.yahoo.config.model.api.ImportedMlModels; import org.junit.Test; import java.io.IOException; @@ -98,7 +98,7 @@ public class ImplicitSearchFieldsTestCase extends SearchDefinitionTestCase { sb.importFile("src/test/examples/nextgen/simple.sd"); sb.build(); assertNotNull(sb.getSearch()); - new DerivedConfiguration(sb.getSearch(), sb.getRankProfileRegistry(), new QueryProfileRegistry(), new ImportedModels()); + new DerivedConfiguration(sb.getSearch(), sb.getRankProfileRegistry(), new QueryProfileRegistry(), new ImportedMlModels()); } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java index 862934935e4..9df03f25cb3 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java @@ -3,6 +3,7 @@ package com.yahoo.searchdefinition.processing; import com.google.common.collect.ImmutableList; import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.model.api.MlModelImporter; import com.yahoo.config.model.test.MockApplicationPackage; import com.yahoo.path.Path; import com.yahoo.search.query.profile.QueryProfileRegistry; @@ -11,8 +12,7 @@ import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.parser.ParseException; -import ai.vespa.rankingexpression.importer.ImportedModels; -import ai.vespa.rankingexpression.importer.ModelImporter; +import com.yahoo.config.model.api.ImportedMlModels; import ai.vespa.rankingexpression.importer.onnx.OnnxImporter; import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter; import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter; @@ -31,9 +31,9 @@ import static org.junit.Assert.assertEquals; */ class RankProfileSearchFixture { - private final ImmutableList<ModelImporter> importers = ImmutableList.of(new TensorFlowImporter(), - new OnnxImporter(), - new XGBoostImporter()); + private final ImmutableList<MlModelImporter> importers = ImmutableList.of(new TensorFlowImporter(), + new OnnxImporter(), + new XGBoostImporter()); private RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); private final QueryProfileRegistry queryProfileRegistry; private Search search; @@ -92,7 +92,7 @@ class RankProfileSearchFixture { public RankProfile compileRankProfile(String rankProfile, Path applicationDir) { RankProfile compiled = rankProfileRegistry.get(search, rankProfile) .compile(queryProfileRegistry, - new ImportedModels(applicationDir.toFile(), importers)); + new ImportedMlModels(applicationDir.toFile(), importers)); compiledRankProfiles.put(rankProfile, compiled); return compiled; } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java index 301c0d1d31e..d8f1a2ba545 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java @@ -8,7 +8,7 @@ import com.yahoo.searchdefinition.derived.DerivedConfiguration; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; -import ai.vespa.rankingexpression.importer.ImportedModels; +import com.yahoo.config.model.api.ImportedMlModels; import org.junit.Test; import java.io.IOException; @@ -39,7 +39,7 @@ public class RankingExpressionsTestCase extends SearchDefinitionTestCase { List<Pair<String, String>> rankProperties = new RawRankProfile(functionsRankProfile, new QueryProfileRegistry(), - new ImportedModels(), + new ImportedMlModels(), new AttributeFields(search)).configProperties(); assertEquals(6, rankProperties.size()); @@ -65,7 +65,7 @@ public class RankingExpressionsTestCase extends SearchDefinitionTestCase { Search search = SearchBuilder.createFromDirectory("src/test/examples/rankingexpressioninfile", registry, new QueryProfileRegistry()).getSearch(); - new DerivedConfiguration(search, registry, new QueryProfileRegistry(), new ImportedModels()); // rank profile parsing happens during deriving + new DerivedConfiguration(search, registry, new QueryProfileRegistry(), new ImportedMlModels()); // rank profile parsing happens during deriving } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java index dc5ad8524b5..fe1d722a49c 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java @@ -17,7 +17,7 @@ import com.yahoo.searchdefinition.SearchDefinitionTestCase; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; -import ai.vespa.rankingexpression.importer.ImportedModels; +import com.yahoo.config.model.api.ImportedMlModels; import org.junit.Test; import java.util.List; @@ -200,10 +200,10 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase { "}\n"); builder.build(true, new BaseDeployLogger()); Search s = builder.getSearch(); - RankProfile test = rankProfileRegistry.get(s, "test").compile(queryProfiles, new ImportedModels()); + RankProfile test = rankProfileRegistry.get(s, "test").compile(queryProfiles, new ImportedMlModels()); List<Pair<String, String>> testRankProperties = new RawRankProfile(test, queryProfiles, - new ImportedModels(), + new ImportedMlModels(), new AttributeFields(s)).configProperties(); return testRankProperties; } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java index 125aefece0e..f5edc83da5c 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java @@ -3,12 +3,12 @@ package com.yahoo.vespa.model.ml; import com.google.common.collect.ImmutableList; import com.yahoo.config.model.ApplicationPackageTester; +import com.yahoo.config.model.api.MlModelImporter; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.io.GrowableByteBuffer; import com.yahoo.io.IOUtils; import com.yahoo.path.Path; import com.yahoo.searchdefinition.RankingConstant; -import ai.vespa.rankingexpression.importer.ModelImporter; import ai.vespa.rankingexpression.importer.onnx.OnnxImporter; import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter; import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter; @@ -32,9 +32,9 @@ import static org.junit.Assert.assertEquals; */ public class ImportedModelTester { - private final ImmutableList<ModelImporter> importers = ImmutableList.of(new TensorFlowImporter(), - new OnnxImporter(), - new XGBoostImporter()); + private final ImmutableList<MlModelImporter> importers = ImmutableList.of(new TensorFlowImporter(), + new OnnxImporter(), + new XGBoostImporter()); private final String modelName; private final Path applicationDir; diff --git a/model-integration/pom.xml b/model-integration/pom.xml index 9971e78d3c5..5a2e7f0dbcd 100644 --- a/model-integration/pom.xml +++ b/model-integration/pom.xml @@ -29,6 +29,12 @@ </dependency> <dependency> <groupId>com.yahoo.vespa</groupId> + <artifactId>config-model-api</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> <artifactId>searchlib</artifactId> <version>${project.version}</version> <scope>provided</scope> diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java index c2235b9abe9..ec4e729f9c7 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java @@ -1,9 +1,9 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.rankingexpression.importer; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.yahoo.collections.Pair; +import com.yahoo.config.model.api.ImportedMlFunction; +import com.yahoo.config.model.api.ImportedMlModel; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.tensor.Tensor; @@ -23,7 +23,7 @@ import java.util.regex.Pattern; * * @author bratseth */ -public class ImportedModel { +public class ImportedModel implements ImportedMlModel { private static final String defaultSignatureName = "default"; @@ -52,15 +52,17 @@ public class ImportedModel { } /** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */ + @Override public String name() { return name; } /** Returns the source path (directory or file) of this model */ + @Override public String source() { return source; } /** Returns an immutable map of the inputs of this */ public Map<String, TensorType> inputs() { return Collections.unmodifiableMap(inputs); } - // CFG + @Override public Optional<String> inputTypeSpec(String input) { return Optional.ofNullable(inputs.get(input)).map(TensorType::toString); } @@ -69,7 +71,7 @@ public class ImportedModel { * Returns an immutable map of the small constants of this, represented as strings on the standard tensor form. * These should have sizes up to a few kb at most, and correspond to constant values given in the source model. */ - // CFG + @Override public Map<String, String> smallConstants() { return asTensorStrings(smallConstants); } boolean hasSmallConstant(String name) { return smallConstants.containsKey(name); } @@ -79,7 +81,7 @@ public class ImportedModel { * These can have sizes in gigabytes and must be distributed to nodes separately from configuration. * For TensorFlow this corresponds to Variable files stored separately. */ - // CFG + @Override public Map<String, String> largeConstants() { return asTensorStrings(largeConstants); } boolean hasLargeConstant(String name) { return largeConstants.containsKey(name); } @@ -97,7 +99,7 @@ public class ImportedModel { * Returns an immutable map of the functions that are part of this model. * Note that the functions themselves are *not* copies and *not* immutable - they must be copied before modification. */ - // CFG + @Override public Map<String, String> functions() { return asExpressionStrings(functions); } /** Returns an immutable map of the signatures of this */ @@ -123,36 +125,36 @@ public class ImportedModel { * if signatures are used, or the expression name if signatures are not used and there are multiple * expressions, and the second is the output name if signature names are used. */ - // CFG - public List<ImportedFunction> outputExpressions() { - List<ImportedFunction> functions = new ArrayList<>(); + @Override + public List<ImportedMlFunction> outputExpressions() { + List<ImportedMlFunction> functions = new ArrayList<>(); for (Map.Entry<String, Signature> signatureEntry : signatures().entrySet()) { for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet()) functions.add(signatureEntry.getValue().outputFunction(outputEntry.getKey(), signatureEntry.getKey() + "." + outputEntry.getKey())); if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs - functions.add(new ImportedFunction(signatureEntry.getKey(), - new ArrayList<>(signatureEntry.getValue().inputs().values()), - expressions().get(signatureEntry.getKey()), - signatureEntry.getValue().inputMap(), - Optional.empty())); + functions.add(new ImportedMlFunction(signatureEntry.getKey(), + new ArrayList<>(signatureEntry.getValue().inputs().values()), + expressions().get(signatureEntry.getKey()).getRoot().toString(), + asTensorTypeStrings(signatureEntry.getValue().inputMap()), + Optional.empty())); } if (signatures().isEmpty()) { // fallback for models without signatures if (expressions().size() == 1) { Map.Entry<String, RankingExpression> singleEntry = this.expressions.entrySet().iterator().next(); - functions.add(new ImportedFunction(singleEntry.getKey(), - new ArrayList<>(inputs.keySet()), - singleEntry.getValue(), - inputs, - Optional.empty())); + functions.add(new ImportedMlFunction(singleEntry.getKey(), + new ArrayList<>(inputs.keySet()), + singleEntry.getValue().getRoot().toString(), + asTensorTypeStrings(inputs), + Optional.empty())); } else { for (Map.Entry<String, RankingExpression> expressionEntry : expressions().entrySet()) { - functions.add(new ImportedFunction(expressionEntry.getKey(), - new ArrayList<>(inputs.keySet()), - expressionEntry.getValue(), - inputs, - Optional.empty())); + functions.add(new ImportedMlFunction(expressionEntry.getKey(), + new ArrayList<>(inputs.keySet()), + expressionEntry.getValue().getRoot().toString(), + asTensorTypeStrings(inputs), + Optional.empty())); } } } @@ -172,6 +174,13 @@ public class ImportedModel { return values; } + private static Map<String, String> asTensorTypeStrings(Map<String, TensorType> map) { + Map<String, String> stringMap = new HashMap<>(); + for (Map.Entry<String, TensorType> entry : map.entrySet()) + stringMap.put(entry.getKey(), entry.getValue().toString()); + return stringMap; + } + private Map<String, String> asExpressionStrings(Map<String, RankingExpression> map) { HashMap<String, String> values = new HashMap<>(); for (Map.Entry<String, RankingExpression> entry : map.entrySet()) @@ -246,16 +255,14 @@ public class ImportedModel { } /** Returns the expression this output references as an imported function */ - public ImportedFunction outputFunction(String outputName, String functionName) { - return new ImportedFunction(functionName, - new ArrayList<>(inputs.values()), - owner().expressions().get(outputs.get(outputName)), - inputMap(), - Optional.empty()); + public ImportedMlFunction outputFunction(String outputName, String functionName) { + return new ImportedMlFunction(functionName, + new ArrayList<>(inputs.values()), + owner().expressions().get(outputs.get(outputName)).getRoot().toString(), + asTensorTypeStrings(inputMap()), + Optional.empty()); } - // CFG - @Override public String toString() { return "signature '" + name + "'"; } @@ -266,37 +273,4 @@ public class ImportedModel { } - // CFG - public static class ImportedFunction { - - private final String name; - private final List<String> arguments; - private final Map<String, String> argumentTypes; - private final String expression; - private final Optional<String> returnType; - - public ImportedFunction(String name, List<String> arguments, RankingExpression expression, - Map<String, TensorType> argumentTypes, Optional<TensorType> returnType) { - this.name = name; - this.arguments = arguments; - this.expression = expression.getRoot().toString(); - this.argumentTypes = asStrings(argumentTypes); - this.returnType = returnType.map(TensorType::toString); - } - - private static Map<String, String> asStrings(Map<String, TensorType> map) { - Map<String, String> stringMap = new HashMap<>(); - for (Map.Entry<String, TensorType> entry : map.entrySet()) - stringMap.put(entry.getKey(), entry.getValue().toString()); - return stringMap; - } - - public String name() { return name; } - public List<String> arguments() { return Collections.unmodifiableList(arguments); } - public Map<String, String> argumentTypes() { return Collections.unmodifiableMap(argumentTypes); } - public String expression() { return expression; } - public Optional<String> returnType() { return returnType; } - - } - } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java index 8a885938bf9..0200a9032a5 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.rankingexpression.importer; +import com.yahoo.config.model.api.MlModelImporter; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; @@ -27,20 +28,22 @@ import java.util.logging.Logger; * * @author lesters */ -public abstract class ModelImporter { +public abstract class ModelImporter implements MlModelImporter { private static final Logger log = Logger.getLogger(ModelImporter.class.getName()); /** Returns whether the file or directory at the given path is of the type which can be imported by this */ + @Override public abstract boolean canImport(String modelPath); - /** Imports the given model */ - public abstract ImportedModel importModel(String modelName, String modelPath); - - final ImportedModel importModel(String modelName, File modelPath) { + @Override + public final ImportedModel importModel(String modelName, File modelPath) { return importModel(modelName, modelPath.toString()); } + /** Imports the given model */ + public abstract ImportedModel importModel(String modelName, String modelPath); + /** * Takes an IntermediateGraph and converts it to a ImportedModel containing * the actual Vespa ranking expressions. diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/package-info.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/package-info.java deleted file mode 100644 index 4473f306dcd..00000000000 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/package-info.java +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -/** - * Model integration. - * - * CAUTION!: Config models depends on this API. It cannot be changed without ensuring compatibility with - * old config models. - */ -@ExportPackage -package ai.vespa.rankingexpression.importer; - -import com.yahoo.osgi.annotation.ExportPackage; |