summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-11-26 21:34:32 +0100
committerJon Bratseth <bratseth@oath.com>2018-11-26 21:34:32 +0100
commit8386d1e455a1dc50669dc450666c305cf1dadb0a (patch)
treeb39885acbe359e36684a73651c6cb7a43069eb54 /model-integration
parent9914692162d87882c8777b9557dbf4cf9e415ac6 (diff)
Create a config model view (api) package under model-integration
This is to avoid having to install config-mode and dependencies in the container at startup as a consequence of wanting model-integration there to make TensorFlow available.
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java4
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlFunction.java37
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModel.java23
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModels.java105
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/MlModelImporter.java16
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/package-info.java9
7 files changed, 193 insertions, 3 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
index ec4e729f9c7..0c5866b87fa 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
@@ -2,8 +2,8 @@
package ai.vespa.rankingexpression.importer;
import com.google.common.collect.ImmutableMap;
-import com.yahoo.config.model.api.ImportedMlFunction;
-import com.yahoo.config.model.api.ImportedMlModel;
+import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction;
+import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.tensor.Tensor;
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
index 0200a9032a5..54c19211277 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
@@ -1,7 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer;
-import com.yahoo.config.model.api.MlModelImporter;
+import ai.vespa.rankingexpression.importer.configmodelview.MlModelImporter;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlFunction.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlFunction.java
new file mode 100644
index 00000000000..2367ac0a4c7
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlFunction.java
@@ -0,0 +1,37 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.configmodelview;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+/**
+ * An imported function of an imported machine-learned model
+ *
+ * @author bratseth
+ */
+public class ImportedMlFunction {
+
+ private final String name;
+ private final List<String> arguments;
+ private final Map<String, String> argumentTypes;
+ private final String expression;
+ private final Optional<String> returnType;
+
+ public ImportedMlFunction(String name, List<String> arguments, String expression,
+ Map<String, String> argumentTypes, Optional<String> returnType) {
+ this.name = name;
+ this.arguments = Collections.unmodifiableList(arguments);
+ this.expression = expression;
+ this.argumentTypes = Collections.unmodifiableMap(argumentTypes);
+ this.returnType = returnType;
+ }
+
+ public String name() { return name; }
+ public List<String> arguments() { return arguments; }
+ public Map<String, String> argumentTypes() { return argumentTypes; }
+ public String expression() { return expression; }
+ public Optional<String> returnType() { return returnType; }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModel.java
new file mode 100644
index 00000000000..e40a06af042
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModel.java
@@ -0,0 +1,23 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.configmodelview;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+/**
+ * Config model view of an imported machine-learned model.
+ *
+ * @author bratseth
+ */
+public interface ImportedMlModel {
+
+ String name();
+ String source();
+ Optional<String> inputTypeSpec(String input);
+ Map<String, String> smallConstants();
+ Map<String, String> largeConstants();
+ Map<String, String> functions();
+ List<ImportedMlFunction> outputExpressions();
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModels.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModels.java
new file mode 100644
index 00000000000..f847af14ed4
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModels.java
@@ -0,0 +1,105 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.configmodelview;
+
+import com.yahoo.path.Path;
+
+import java.io.File;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Optional;
+
+/**
+ * All models imported from the models/ directory in the application package.
+ * If this is empty it may be due to either not having any models in the application package,
+ * or this being created for a ZooKeeper application package, which does not have imported models.
+ *
+ * @author bratseth
+ */
+public class ImportedMlModels {
+
+ /** All imported models, indexed by their names */
+ private final Map<String, ImportedMlModel> importedModels;
+
+ /** Create a null imported models */
+ public ImportedMlModels() {
+ importedModels = Collections.emptyMap();
+ }
+
+ public ImportedMlModels(File modelsDirectory, Collection<MlModelImporter> importers) {
+ Map<String, ImportedMlModel> models = new HashMap<>();
+
+ // Find all subdirectories recursively which contains a model we can read
+ importRecursively(modelsDirectory, models, importers);
+ importedModels = Collections.unmodifiableMap(models);
+ }
+
+ /**
+ * Returns the model at the given location in the application package.
+ *
+ * @param modelPath the path to this model (file or directory, depending on model type)
+ * under the application package, both from the root or relative to the
+ * models directory works
+ * @return the model at this path or null if none
+ */
+ public ImportedMlModel get(File modelPath) {
+ return importedModels.get(toName(modelPath));
+ }
+
+ /** Returns an immutable collection of all the imported models */
+ public Collection<ImportedMlModel> all() {
+ return importedModels.values();
+ }
+
+ private static void importRecursively(File dir,
+ Map<String, ImportedMlModel> models,
+ Collection<MlModelImporter> importers) {
+ if ( ! dir.isDirectory()) return;
+
+ Arrays.stream(dir.listFiles()).sorted().forEach(child -> {
+ Optional<MlModelImporter> importer = findImporterOf(child, importers);
+ if (importer.isPresent()) {
+ String name = toName(child);
+ ImportedMlModel existing = models.get(name);
+ if (existing != null)
+ throw new IllegalArgumentException("The models in " + child + " and " + existing.source() +
+ " both resolve to the model name '" + name + "'");
+ models.put(name, importer.get().importModel(name, child));
+ }
+ else {
+ importRecursively(child, models, importers);
+ }
+ });
+ }
+
+ private static Optional<MlModelImporter> findImporterOf(File path, Collection<MlModelImporter> importers) {
+ return importers.stream().filter(item -> item.canImport(path.toString())).findFirst();
+ }
+
+ private static String toName(File modelFile) {
+ Path modelPath = Path.fromString(modelFile.toString());
+ if (modelFile.isFile())
+ modelPath = stripFileEnding(modelPath);
+ String localPath = concatenateAfterModelsDirectory(modelPath);
+ return localPath.replace('.', '_');
+ }
+
+ private static Path stripFileEnding(Path path) {
+ int dotIndex = path.last().lastIndexOf(".");
+ if (dotIndex <= 0) return path;
+ return path.withLast(path.last().substring(0, dotIndex));
+ }
+
+ private static String concatenateAfterModelsDirectory(Path path) {
+ boolean afterModels = false;
+ StringBuilder result = new StringBuilder();
+ for (String element : path.elements()) {
+ if (afterModels) result.append(element).append("_");
+ if (element.equals("models")) afterModels = true;
+ }
+ return result.substring(0, result.length()-1);
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/MlModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/MlModelImporter.java
new file mode 100644
index 00000000000..d294872113b
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/MlModelImporter.java
@@ -0,0 +1,16 @@
+package ai.vespa.rankingexpression.importer.configmodelview;
+
+import java.io.File;
+
+/**
+ * Config model view of a machine-learned model importer
+ *
+ * @author bratseth
+ */
+public interface MlModelImporter {
+
+ boolean canImport(String modelPath);
+
+ ImportedMlModel importModel(String modelName, File modelPath);
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/package-info.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/package-info.java
new file mode 100644
index 00000000000..5a844bb5773
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/package-info.java
@@ -0,0 +1,9 @@
+/**
+ * The config models view of imported models. This API cannot be changed withoug taking earlier config models
+ * into account, not even on major versions.
+ */
+@ExportPackage
+package ai.vespa.rankingexpression.importer.configmodelview;
+
+import com.yahoo.osgi.annotation.ExportPackage;
+