diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-11-26 21:34:32 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-11-26 21:34:32 +0100 |
commit | 8386d1e455a1dc50669dc450666c305cf1dadb0a (patch) | |
tree | b39885acbe359e36684a73651c6cb7a43069eb54 /model-integration | |
parent | 9914692162d87882c8777b9557dbf4cf9e415ac6 (diff) |
Create a config model view (api) package under model-integration
This is to avoid having to install config-mode and dependencies
in the container at startup as a consequence of wanting model-integration
there to make TensorFlow available.
Diffstat (limited to 'model-integration')
7 files changed, 193 insertions, 3 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java index ec4e729f9c7..0c5866b87fa 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java @@ -2,8 +2,8 @@ package ai.vespa.rankingexpression.importer; import com.google.common.collect.ImmutableMap; -import com.yahoo.config.model.api.ImportedMlFunction; -import com.yahoo.config.model.api.ImportedMlModel; +import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction; +import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.tensor.Tensor; diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java index 0200a9032a5..54c19211277 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java @@ -1,7 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.rankingexpression.importer; -import com.yahoo.config.model.api.MlModelImporter; +import ai.vespa.rankingexpression.importer.configmodelview.MlModelImporter; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlFunction.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlFunction.java new file mode 100644 index 00000000000..2367ac0a4c7 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlFunction.java @@ -0,0 +1,37 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.configmodelview; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +/** + * An imported function of an imported machine-learned model + * + * @author bratseth + */ +public class ImportedMlFunction { + + private final String name; + private final List<String> arguments; + private final Map<String, String> argumentTypes; + private final String expression; + private final Optional<String> returnType; + + public ImportedMlFunction(String name, List<String> arguments, String expression, + Map<String, String> argumentTypes, Optional<String> returnType) { + this.name = name; + this.arguments = Collections.unmodifiableList(arguments); + this.expression = expression; + this.argumentTypes = Collections.unmodifiableMap(argumentTypes); + this.returnType = returnType; + } + + public String name() { return name; } + public List<String> arguments() { return arguments; } + public Map<String, String> argumentTypes() { return argumentTypes; } + public String expression() { return expression; } + public Optional<String> returnType() { return returnType; } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModel.java new file mode 100644 index 00000000000..e40a06af042 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModel.java @@ -0,0 +1,23 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.configmodelview; + +import java.util.List; +import java.util.Map; +import java.util.Optional; + +/** + * Config model view of an imported machine-learned model. + * + * @author bratseth + */ +public interface ImportedMlModel { + + String name(); + String source(); + Optional<String> inputTypeSpec(String input); + Map<String, String> smallConstants(); + Map<String, String> largeConstants(); + Map<String, String> functions(); + List<ImportedMlFunction> outputExpressions(); + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModels.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModels.java new file mode 100644 index 00000000000..f847af14ed4 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModels.java @@ -0,0 +1,105 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.configmodelview; + +import com.yahoo.path.Path; + +import java.io.File; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; + +/** + * All models imported from the models/ directory in the application package. + * If this is empty it may be due to either not having any models in the application package, + * or this being created for a ZooKeeper application package, which does not have imported models. + * + * @author bratseth + */ +public class ImportedMlModels { + + /** All imported models, indexed by their names */ + private final Map<String, ImportedMlModel> importedModels; + + /** Create a null imported models */ + public ImportedMlModels() { + importedModels = Collections.emptyMap(); + } + + public ImportedMlModels(File modelsDirectory, Collection<MlModelImporter> importers) { + Map<String, ImportedMlModel> models = new HashMap<>(); + + // Find all subdirectories recursively which contains a model we can read + importRecursively(modelsDirectory, models, importers); + importedModels = Collections.unmodifiableMap(models); + } + + /** + * Returns the model at the given location in the application package. + * + * @param modelPath the path to this model (file or directory, depending on model type) + * under the application package, both from the root or relative to the + * models directory works + * @return the model at this path or null if none + */ + public ImportedMlModel get(File modelPath) { + return importedModels.get(toName(modelPath)); + } + + /** Returns an immutable collection of all the imported models */ + public Collection<ImportedMlModel> all() { + return importedModels.values(); + } + + private static void importRecursively(File dir, + Map<String, ImportedMlModel> models, + Collection<MlModelImporter> importers) { + if ( ! dir.isDirectory()) return; + + Arrays.stream(dir.listFiles()).sorted().forEach(child -> { + Optional<MlModelImporter> importer = findImporterOf(child, importers); + if (importer.isPresent()) { + String name = toName(child); + ImportedMlModel existing = models.get(name); + if (existing != null) + throw new IllegalArgumentException("The models in " + child + " and " + existing.source() + + " both resolve to the model name '" + name + "'"); + models.put(name, importer.get().importModel(name, child)); + } + else { + importRecursively(child, models, importers); + } + }); + } + + private static Optional<MlModelImporter> findImporterOf(File path, Collection<MlModelImporter> importers) { + return importers.stream().filter(item -> item.canImport(path.toString())).findFirst(); + } + + private static String toName(File modelFile) { + Path modelPath = Path.fromString(modelFile.toString()); + if (modelFile.isFile()) + modelPath = stripFileEnding(modelPath); + String localPath = concatenateAfterModelsDirectory(modelPath); + return localPath.replace('.', '_'); + } + + private static Path stripFileEnding(Path path) { + int dotIndex = path.last().lastIndexOf("."); + if (dotIndex <= 0) return path; + return path.withLast(path.last().substring(0, dotIndex)); + } + + private static String concatenateAfterModelsDirectory(Path path) { + boolean afterModels = false; + StringBuilder result = new StringBuilder(); + for (String element : path.elements()) { + if (afterModels) result.append(element).append("_"); + if (element.equals("models")) afterModels = true; + } + return result.substring(0, result.length()-1); + } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/MlModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/MlModelImporter.java new file mode 100644 index 00000000000..d294872113b --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/MlModelImporter.java @@ -0,0 +1,16 @@ +package ai.vespa.rankingexpression.importer.configmodelview; + +import java.io.File; + +/** + * Config model view of a machine-learned model importer + * + * @author bratseth + */ +public interface MlModelImporter { + + boolean canImport(String modelPath); + + ImportedMlModel importModel(String modelName, File modelPath); + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/package-info.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/package-info.java new file mode 100644 index 00000000000..5a844bb5773 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/package-info.java @@ -0,0 +1,9 @@ +/** + * The config models view of imported models. This API cannot be changed withoug taking earlier config models + * into account, not even on major versions. + */ +@ExportPackage +package ai.vespa.rankingexpression.importer.configmodelview; + +import com.yahoo.osgi.annotation.ExportPackage; + |