summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-11-26 15:00:16 +0100
committerJon Bratseth <bratseth@oath.com>2018-11-26 15:00:16 +0100
commite3c64c48445d85054292b6fe90c3347d5ec327f2 (patch)
tree9c4f8509699a74ec614633159d1eb3171482d456 /model-integration
parent04cc3c48130b8397c04335948e5971914b2eaf22 (diff)
Model the config model view of model-integration
This is to make it clearer that these methods are part of the config model API.
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/pom.xml6
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java106
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java111
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java13
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/package-info.java11
5 files changed, 54 insertions, 193 deletions
diff --git a/model-integration/pom.xml b/model-integration/pom.xml
index 9971e78d3c5..5a2e7f0dbcd 100644
--- a/model-integration/pom.xml
+++ b/model-integration/pom.xml
@@ -29,6 +29,12 @@
</dependency>
<dependency>
<groupId>com.yahoo.vespa</groupId>
+ <artifactId>config-model-api</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>com.yahoo.vespa</groupId>
<artifactId>searchlib</artifactId>
<version>${project.version}</version>
<scope>provided</scope>
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
index c2235b9abe9..ec4e729f9c7 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
@@ -1,9 +1,9 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer;
-import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
-import com.yahoo.collections.Pair;
+import com.yahoo.config.model.api.ImportedMlFunction;
+import com.yahoo.config.model.api.ImportedMlModel;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.tensor.Tensor;
@@ -23,7 +23,7 @@ import java.util.regex.Pattern;
*
* @author bratseth
*/
-public class ImportedModel {
+public class ImportedModel implements ImportedMlModel {
private static final String defaultSignatureName = "default";
@@ -52,15 +52,17 @@ public class ImportedModel {
}
/** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */
+ @Override
public String name() { return name; }
/** Returns the source path (directory or file) of this model */
+ @Override
public String source() { return source; }
/** Returns an immutable map of the inputs of this */
public Map<String, TensorType> inputs() { return Collections.unmodifiableMap(inputs); }
- // CFG
+ @Override
public Optional<String> inputTypeSpec(String input) {
return Optional.ofNullable(inputs.get(input)).map(TensorType::toString);
}
@@ -69,7 +71,7 @@ public class ImportedModel {
* Returns an immutable map of the small constants of this, represented as strings on the standard tensor form.
* These should have sizes up to a few kb at most, and correspond to constant values given in the source model.
*/
- // CFG
+ @Override
public Map<String, String> smallConstants() { return asTensorStrings(smallConstants); }
boolean hasSmallConstant(String name) { return smallConstants.containsKey(name); }
@@ -79,7 +81,7 @@ public class ImportedModel {
* These can have sizes in gigabytes and must be distributed to nodes separately from configuration.
* For TensorFlow this corresponds to Variable files stored separately.
*/
- // CFG
+ @Override
public Map<String, String> largeConstants() { return asTensorStrings(largeConstants); }
boolean hasLargeConstant(String name) { return largeConstants.containsKey(name); }
@@ -97,7 +99,7 @@ public class ImportedModel {
* Returns an immutable map of the functions that are part of this model.
* Note that the functions themselves are *not* copies and *not* immutable - they must be copied before modification.
*/
- // CFG
+ @Override
public Map<String, String> functions() { return asExpressionStrings(functions); }
/** Returns an immutable map of the signatures of this */
@@ -123,36 +125,36 @@ public class ImportedModel {
* if signatures are used, or the expression name if signatures are not used and there are multiple
* expressions, and the second is the output name if signature names are used.
*/
- // CFG
- public List<ImportedFunction> outputExpressions() {
- List<ImportedFunction> functions = new ArrayList<>();
+ @Override
+ public List<ImportedMlFunction> outputExpressions() {
+ List<ImportedMlFunction> functions = new ArrayList<>();
for (Map.Entry<String, Signature> signatureEntry : signatures().entrySet()) {
for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet())
functions.add(signatureEntry.getValue().outputFunction(outputEntry.getKey(),
signatureEntry.getKey() + "." + outputEntry.getKey()));
if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs
- functions.add(new ImportedFunction(signatureEntry.getKey(),
- new ArrayList<>(signatureEntry.getValue().inputs().values()),
- expressions().get(signatureEntry.getKey()),
- signatureEntry.getValue().inputMap(),
- Optional.empty()));
+ functions.add(new ImportedMlFunction(signatureEntry.getKey(),
+ new ArrayList<>(signatureEntry.getValue().inputs().values()),
+ expressions().get(signatureEntry.getKey()).getRoot().toString(),
+ asTensorTypeStrings(signatureEntry.getValue().inputMap()),
+ Optional.empty()));
}
if (signatures().isEmpty()) { // fallback for models without signatures
if (expressions().size() == 1) {
Map.Entry<String, RankingExpression> singleEntry = this.expressions.entrySet().iterator().next();
- functions.add(new ImportedFunction(singleEntry.getKey(),
- new ArrayList<>(inputs.keySet()),
- singleEntry.getValue(),
- inputs,
- Optional.empty()));
+ functions.add(new ImportedMlFunction(singleEntry.getKey(),
+ new ArrayList<>(inputs.keySet()),
+ singleEntry.getValue().getRoot().toString(),
+ asTensorTypeStrings(inputs),
+ Optional.empty()));
}
else {
for (Map.Entry<String, RankingExpression> expressionEntry : expressions().entrySet()) {
- functions.add(new ImportedFunction(expressionEntry.getKey(),
- new ArrayList<>(inputs.keySet()),
- expressionEntry.getValue(),
- inputs,
- Optional.empty()));
+ functions.add(new ImportedMlFunction(expressionEntry.getKey(),
+ new ArrayList<>(inputs.keySet()),
+ expressionEntry.getValue().getRoot().toString(),
+ asTensorTypeStrings(inputs),
+ Optional.empty()));
}
}
}
@@ -172,6 +174,13 @@ public class ImportedModel {
return values;
}
+ private static Map<String, String> asTensorTypeStrings(Map<String, TensorType> map) {
+ Map<String, String> stringMap = new HashMap<>();
+ for (Map.Entry<String, TensorType> entry : map.entrySet())
+ stringMap.put(entry.getKey(), entry.getValue().toString());
+ return stringMap;
+ }
+
private Map<String, String> asExpressionStrings(Map<String, RankingExpression> map) {
HashMap<String, String> values = new HashMap<>();
for (Map.Entry<String, RankingExpression> entry : map.entrySet())
@@ -246,16 +255,14 @@ public class ImportedModel {
}
/** Returns the expression this output references as an imported function */
- public ImportedFunction outputFunction(String outputName, String functionName) {
- return new ImportedFunction(functionName,
- new ArrayList<>(inputs.values()),
- owner().expressions().get(outputs.get(outputName)),
- inputMap(),
- Optional.empty());
+ public ImportedMlFunction outputFunction(String outputName, String functionName) {
+ return new ImportedMlFunction(functionName,
+ new ArrayList<>(inputs.values()),
+ owner().expressions().get(outputs.get(outputName)).getRoot().toString(),
+ asTensorTypeStrings(inputMap()),
+ Optional.empty());
}
- // CFG
-
@Override
public String toString() { return "signature '" + name + "'"; }
@@ -266,37 +273,4 @@ public class ImportedModel {
}
- // CFG
- public static class ImportedFunction {
-
- private final String name;
- private final List<String> arguments;
- private final Map<String, String> argumentTypes;
- private final String expression;
- private final Optional<String> returnType;
-
- public ImportedFunction(String name, List<String> arguments, RankingExpression expression,
- Map<String, TensorType> argumentTypes, Optional<TensorType> returnType) {
- this.name = name;
- this.arguments = arguments;
- this.expression = expression.getRoot().toString();
- this.argumentTypes = asStrings(argumentTypes);
- this.returnType = returnType.map(TensorType::toString);
- }
-
- private static Map<String, String> asStrings(Map<String, TensorType> map) {
- Map<String, String> stringMap = new HashMap<>();
- for (Map.Entry<String, TensorType> entry : map.entrySet())
- stringMap.put(entry.getKey(), entry.getValue().toString());
- return stringMap;
- }
-
- public String name() { return name; }
- public List<String> arguments() { return Collections.unmodifiableList(arguments); }
- public Map<String, String> argumentTypes() { return Collections.unmodifiableMap(argumentTypes); }
- public String expression() { return expression; }
- public Optional<String> returnType() { return returnType; }
-
- }
-
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java
deleted file mode 100644
index bfdaaca1dd7..00000000000
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java
+++ /dev/null
@@ -1,111 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package ai.vespa.rankingexpression.importer;
-
-import com.google.common.collect.ImmutableMap;
-import com.yahoo.path.Path;
-
-import java.io.File;
-import java.util.Arrays;
-import java.util.Collection;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Optional;
-
-/**
- * All models imported from the models/ directory in the application package.
- * If this is empty it may be due to either not having any models in the application package,
- * or this being created for a ZooKeeper application package, which does not have imported models.
- *
- * @author bratseth
- */
-public class ImportedModels {
-
- /** All imported models, indexed by their names */
- private final ImmutableMap<String, ImportedModel> importedModels;
-
- /** Create a null imported models */
- public ImportedModels() {
- importedModels = ImmutableMap.of();
- }
-
- public ImportedModels(File modelsDirectory, Collection<ModelImporter> importers) {
- Map<String, ImportedModel> models = new HashMap<>();
-
- // Find all subdirectories recursively which contains a model we can read
- importRecursively(modelsDirectory, models, importers);
- importedModels = ImmutableMap.copyOf(models);
- }
-
- private static void importRecursively(File dir,
- Map<String, ImportedModel> models,
- Collection<ModelImporter> importers) {
- if ( ! dir.isDirectory()) return;
-
- Arrays.stream(dir.listFiles()).sorted().forEach(child -> {
- Optional<ModelImporter> importer = findImporterOf(child, importers);
- if (importer.isPresent()) {
- String name = toName(child);
- ImportedModel existing = models.get(name);
- if (existing != null)
- throw new IllegalArgumentException("The models in " + child + " and " + existing.source() +
- " both resolve to the model name '" + name + "'");
- models.put(name, importer.get().importModel(name, child));
- }
- else {
- importRecursively(child, models, importers);
- }
- });
- }
-
- private static Optional<ModelImporter> findImporterOf(File path, Collection<ModelImporter> importers) {
- return importers.stream().filter(item -> item.canImport(path.toString())).findFirst();
- }
-
- /**
- * Returns the model at the given location in the application package.
- *
- * @param modelPath the path to this model (file or directory, depending on model type)
- * under the application package, both from the root or relative to the
- * models directory works
- * @return the model at this path or null if none
- */
- // CFG
- public ImportedModel get(File modelPath) {
- return importedModels.get(toName(modelPath));
- }
-
- public ImportedModel get(String modelName) {
- return importedModels.get(modelName);
- }
-
- /** Returns an immutable collection of all the imported models */
- // CFG
- public Collection<ImportedModel> all() {
- return importedModels.values();
- }
-
- private static String toName(File modelFile) {
- Path modelPath = Path.fromString(modelFile.toString());
- if (modelFile.isFile())
- modelPath = stripFileEnding(modelPath);
- String localPath = concatenateAfterModelsDirectory(modelPath);
- return localPath.replace('.', '_');
- }
-
- private static Path stripFileEnding(Path path) {
- int dotIndex = path.last().lastIndexOf(".");
- if (dotIndex <= 0) return path;
- return path.withLast(path.last().substring(0, dotIndex));
- }
-
- private static String concatenateAfterModelsDirectory(Path path) {
- boolean afterModels = false;
- StringBuilder result = new StringBuilder();
- for (String element : path.elements()) {
- if (afterModels) result.append(element).append("_");
- if (element.equals("models")) afterModels = true;
- }
- return result.substring(0, result.length()-1);
- }
-
-}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
index 8a885938bf9..0200a9032a5 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer;
+import com.yahoo.config.model.api.MlModelImporter;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
@@ -27,20 +28,22 @@ import java.util.logging.Logger;
*
* @author lesters
*/
-public abstract class ModelImporter {
+public abstract class ModelImporter implements MlModelImporter {
private static final Logger log = Logger.getLogger(ModelImporter.class.getName());
/** Returns whether the file or directory at the given path is of the type which can be imported by this */
+ @Override
public abstract boolean canImport(String modelPath);
- /** Imports the given model */
- public abstract ImportedModel importModel(String modelName, String modelPath);
-
- final ImportedModel importModel(String modelName, File modelPath) {
+ @Override
+ public final ImportedModel importModel(String modelName, File modelPath) {
return importModel(modelName, modelPath.toString());
}
+ /** Imports the given model */
+ public abstract ImportedModel importModel(String modelName, String modelPath);
+
/**
* Takes an IntermediateGraph and converts it to a ImportedModel containing
* the actual Vespa ranking expressions.
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/package-info.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/package-info.java
deleted file mode 100644
index 4473f306dcd..00000000000
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/package-info.java
+++ /dev/null
@@ -1,11 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-/**
- * Model integration.
- *
- * CAUTION!: Config models depends on this API. It cannot be changed without ensuring compatibility with
- * old config models.
- */
-@ExportPackage
-package ai.vespa.rankingexpression.importer;
-
-import com.yahoo.osgi.annotation.ExportPackage;