From e3c64c48445d85054292b6fe90c3347d5ec327f2 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Mon, 26 Nov 2018 15:00:16 +0100 Subject: Model the config model view of model-integration This is to make it clearer that these methods are part of the config model API. --- .../rankingexpression/importer/ImportedModel.java | 106 ++++++++------------ .../rankingexpression/importer/ImportedModels.java | 111 --------------------- .../rankingexpression/importer/ModelImporter.java | 13 ++- .../rankingexpression/importer/package-info.java | 11 -- 4 files changed, 48 insertions(+), 193 deletions(-) delete mode 100644 model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java delete mode 100644 model-integration/src/main/java/ai/vespa/rankingexpression/importer/package-info.java (limited to 'model-integration/src') diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java index c2235b9abe9..ec4e729f9c7 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java @@ -1,9 +1,9 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.rankingexpression.importer; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.yahoo.collections.Pair; +import com.yahoo.config.model.api.ImportedMlFunction; +import com.yahoo.config.model.api.ImportedMlModel; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.tensor.Tensor; @@ -23,7 +23,7 @@ import java.util.regex.Pattern; * * @author bratseth */ -public class ImportedModel { +public class ImportedModel implements ImportedMlModel { private static final String defaultSignatureName = "default"; @@ -52,15 +52,17 @@ public class ImportedModel { } /** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */ + @Override public String name() { return name; } /** Returns the source path (directory or file) of this model */ + @Override public String source() { return source; } /** Returns an immutable map of the inputs of this */ public Map inputs() { return Collections.unmodifiableMap(inputs); } - // CFG + @Override public Optional inputTypeSpec(String input) { return Optional.ofNullable(inputs.get(input)).map(TensorType::toString); } @@ -69,7 +71,7 @@ public class ImportedModel { * Returns an immutable map of the small constants of this, represented as strings on the standard tensor form. * These should have sizes up to a few kb at most, and correspond to constant values given in the source model. */ - // CFG + @Override public Map smallConstants() { return asTensorStrings(smallConstants); } boolean hasSmallConstant(String name) { return smallConstants.containsKey(name); } @@ -79,7 +81,7 @@ public class ImportedModel { * These can have sizes in gigabytes and must be distributed to nodes separately from configuration. * For TensorFlow this corresponds to Variable files stored separately. */ - // CFG + @Override public Map largeConstants() { return asTensorStrings(largeConstants); } boolean hasLargeConstant(String name) { return largeConstants.containsKey(name); } @@ -97,7 +99,7 @@ public class ImportedModel { * Returns an immutable map of the functions that are part of this model. * Note that the functions themselves are *not* copies and *not* immutable - they must be copied before modification. */ - // CFG + @Override public Map functions() { return asExpressionStrings(functions); } /** Returns an immutable map of the signatures of this */ @@ -123,36 +125,36 @@ public class ImportedModel { * if signatures are used, or the expression name if signatures are not used and there are multiple * expressions, and the second is the output name if signature names are used. */ - // CFG - public List outputExpressions() { - List functions = new ArrayList<>(); + @Override + public List outputExpressions() { + List functions = new ArrayList<>(); for (Map.Entry signatureEntry : signatures().entrySet()) { for (Map.Entry outputEntry : signatureEntry.getValue().outputs().entrySet()) functions.add(signatureEntry.getValue().outputFunction(outputEntry.getKey(), signatureEntry.getKey() + "." + outputEntry.getKey())); if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs - functions.add(new ImportedFunction(signatureEntry.getKey(), - new ArrayList<>(signatureEntry.getValue().inputs().values()), - expressions().get(signatureEntry.getKey()), - signatureEntry.getValue().inputMap(), - Optional.empty())); + functions.add(new ImportedMlFunction(signatureEntry.getKey(), + new ArrayList<>(signatureEntry.getValue().inputs().values()), + expressions().get(signatureEntry.getKey()).getRoot().toString(), + asTensorTypeStrings(signatureEntry.getValue().inputMap()), + Optional.empty())); } if (signatures().isEmpty()) { // fallback for models without signatures if (expressions().size() == 1) { Map.Entry singleEntry = this.expressions.entrySet().iterator().next(); - functions.add(new ImportedFunction(singleEntry.getKey(), - new ArrayList<>(inputs.keySet()), - singleEntry.getValue(), - inputs, - Optional.empty())); + functions.add(new ImportedMlFunction(singleEntry.getKey(), + new ArrayList<>(inputs.keySet()), + singleEntry.getValue().getRoot().toString(), + asTensorTypeStrings(inputs), + Optional.empty())); } else { for (Map.Entry expressionEntry : expressions().entrySet()) { - functions.add(new ImportedFunction(expressionEntry.getKey(), - new ArrayList<>(inputs.keySet()), - expressionEntry.getValue(), - inputs, - Optional.empty())); + functions.add(new ImportedMlFunction(expressionEntry.getKey(), + new ArrayList<>(inputs.keySet()), + expressionEntry.getValue().getRoot().toString(), + asTensorTypeStrings(inputs), + Optional.empty())); } } } @@ -172,6 +174,13 @@ public class ImportedModel { return values; } + private static Map asTensorTypeStrings(Map map) { + Map stringMap = new HashMap<>(); + for (Map.Entry entry : map.entrySet()) + stringMap.put(entry.getKey(), entry.getValue().toString()); + return stringMap; + } + private Map asExpressionStrings(Map map) { HashMap values = new HashMap<>(); for (Map.Entry entry : map.entrySet()) @@ -246,16 +255,14 @@ public class ImportedModel { } /** Returns the expression this output references as an imported function */ - public ImportedFunction outputFunction(String outputName, String functionName) { - return new ImportedFunction(functionName, - new ArrayList<>(inputs.values()), - owner().expressions().get(outputs.get(outputName)), - inputMap(), - Optional.empty()); + public ImportedMlFunction outputFunction(String outputName, String functionName) { + return new ImportedMlFunction(functionName, + new ArrayList<>(inputs.values()), + owner().expressions().get(outputs.get(outputName)).getRoot().toString(), + asTensorTypeStrings(inputMap()), + Optional.empty()); } - // CFG - @Override public String toString() { return "signature '" + name + "'"; } @@ -266,37 +273,4 @@ public class ImportedModel { } - // CFG - public static class ImportedFunction { - - private final String name; - private final List arguments; - private final Map argumentTypes; - private final String expression; - private final Optional returnType; - - public ImportedFunction(String name, List arguments, RankingExpression expression, - Map argumentTypes, Optional returnType) { - this.name = name; - this.arguments = arguments; - this.expression = expression.getRoot().toString(); - this.argumentTypes = asStrings(argumentTypes); - this.returnType = returnType.map(TensorType::toString); - } - - private static Map asStrings(Map map) { - Map stringMap = new HashMap<>(); - for (Map.Entry entry : map.entrySet()) - stringMap.put(entry.getKey(), entry.getValue().toString()); - return stringMap; - } - - public String name() { return name; } - public List arguments() { return Collections.unmodifiableList(arguments); } - public Map argumentTypes() { return Collections.unmodifiableMap(argumentTypes); } - public String expression() { return expression; } - public Optional returnType() { return returnType; } - - } - } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java deleted file mode 100644 index bfdaaca1dd7..00000000000 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.rankingexpression.importer; - -import com.google.common.collect.ImmutableMap; -import com.yahoo.path.Path; - -import java.io.File; -import java.util.Arrays; -import java.util.Collection; -import java.util.HashMap; -import java.util.Map; -import java.util.Optional; - -/** - * All models imported from the models/ directory in the application package. - * If this is empty it may be due to either not having any models in the application package, - * or this being created for a ZooKeeper application package, which does not have imported models. - * - * @author bratseth - */ -public class ImportedModels { - - /** All imported models, indexed by their names */ - private final ImmutableMap importedModels; - - /** Create a null imported models */ - public ImportedModels() { - importedModels = ImmutableMap.of(); - } - - public ImportedModels(File modelsDirectory, Collection importers) { - Map models = new HashMap<>(); - - // Find all subdirectories recursively which contains a model we can read - importRecursively(modelsDirectory, models, importers); - importedModels = ImmutableMap.copyOf(models); - } - - private static void importRecursively(File dir, - Map models, - Collection importers) { - if ( ! dir.isDirectory()) return; - - Arrays.stream(dir.listFiles()).sorted().forEach(child -> { - Optional importer = findImporterOf(child, importers); - if (importer.isPresent()) { - String name = toName(child); - ImportedModel existing = models.get(name); - if (existing != null) - throw new IllegalArgumentException("The models in " + child + " and " + existing.source() + - " both resolve to the model name '" + name + "'"); - models.put(name, importer.get().importModel(name, child)); - } - else { - importRecursively(child, models, importers); - } - }); - } - - private static Optional findImporterOf(File path, Collection importers) { - return importers.stream().filter(item -> item.canImport(path.toString())).findFirst(); - } - - /** - * Returns the model at the given location in the application package. - * - * @param modelPath the path to this model (file or directory, depending on model type) - * under the application package, both from the root or relative to the - * models directory works - * @return the model at this path or null if none - */ - // CFG - public ImportedModel get(File modelPath) { - return importedModels.get(toName(modelPath)); - } - - public ImportedModel get(String modelName) { - return importedModels.get(modelName); - } - - /** Returns an immutable collection of all the imported models */ - // CFG - public Collection all() { - return importedModels.values(); - } - - private static String toName(File modelFile) { - Path modelPath = Path.fromString(modelFile.toString()); - if (modelFile.isFile()) - modelPath = stripFileEnding(modelPath); - String localPath = concatenateAfterModelsDirectory(modelPath); - return localPath.replace('.', '_'); - } - - private static Path stripFileEnding(Path path) { - int dotIndex = path.last().lastIndexOf("."); - if (dotIndex <= 0) return path; - return path.withLast(path.last().substring(0, dotIndex)); - } - - private static String concatenateAfterModelsDirectory(Path path) { - boolean afterModels = false; - StringBuilder result = new StringBuilder(); - for (String element : path.elements()) { - if (afterModels) result.append(element).append("_"); - if (element.equals("models")) afterModels = true; - } - return result.substring(0, result.length()-1); - } - -} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java index 8a885938bf9..0200a9032a5 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.rankingexpression.importer; +import com.yahoo.config.model.api.MlModelImporter; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; @@ -27,20 +28,22 @@ import java.util.logging.Logger; * * @author lesters */ -public abstract class ModelImporter { +public abstract class ModelImporter implements MlModelImporter { private static final Logger log = Logger.getLogger(ModelImporter.class.getName()); /** Returns whether the file or directory at the given path is of the type which can be imported by this */ + @Override public abstract boolean canImport(String modelPath); - /** Imports the given model */ - public abstract ImportedModel importModel(String modelName, String modelPath); - - final ImportedModel importModel(String modelName, File modelPath) { + @Override + public final ImportedModel importModel(String modelName, File modelPath) { return importModel(modelName, modelPath.toString()); } + /** Imports the given model */ + public abstract ImportedModel importModel(String modelName, String modelPath); + /** * Takes an IntermediateGraph and converts it to a ImportedModel containing * the actual Vespa ranking expressions. diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/package-info.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/package-info.java deleted file mode 100644 index 4473f306dcd..00000000000 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/package-info.java +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -/** - * Model integration. - * - * CAUTION!: Config models depends on this API. It cannot be changed without ensuring compatibility with - * old config models. - */ -@ExportPackage -package ai.vespa.rankingexpression.importer; - -import com.yahoo.osgi.annotation.ExportPackage; -- cgit v1.2.3