diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-08-29 16:01:28 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-08-29 16:01:28 +0200 |
commit | 6958f2a641eaad0c61249e7bca887a1405e17d02 (patch) | |
tree | 49c8299f2fffbff2d814eea0cf86e8081d90cce8 | |
parent | 06323aff51bf054d64ef2bea001917a22433717f (diff) | |
parent | c27d2709eea8b697f8e099c1af12872bb7a75610 (diff) |
Merge pull request #6722 from vespa-engine/revert-6713-bratseth/generate-rank-profiles-for-all-models-part-10
Revert "Read stored models from Zk package for global rank profiles"
13 files changed, 274 insertions, 335 deletions
diff --git a/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java b/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java index c7258d8aede..f926259f115 100644 --- a/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java +++ b/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java @@ -159,10 +159,12 @@ public interface ApplicationPackage { } /** - * Returns inforamtion about a file + * Gets a file from the root of the application package + * * * @param relativePath the relative path of the file within this application package. - * @return information abut the file, returned whether or not the file exists + * @return information abut the file + * @throws IllegalArgumentException if the given path does not exist */ ApplicationFile getFile(Path relativePath); diff --git a/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java b/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java index 757dab4cbf3..7404ae14a5d 100644 --- a/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java +++ b/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java @@ -16,17 +16,11 @@ import com.yahoo.searchdefinition.parser.ParseException; import com.yahoo.vespa.config.ConfigDefinitionKey; import com.yahoo.config.application.api.ApplicationPackage; -import java.io.BufferedInputStream; import java.io.File; -import java.io.FileInputStream; -import java.io.FileNotFoundException; import java.io.IOException; -import java.io.InputStream; import java.io.Reader; import java.io.StringReader; -import java.io.UncheckedIOException; import java.util.*; -import java.util.stream.Collectors; /** * For testing purposes only @@ -119,7 +113,7 @@ public class MockApplicationPackage implements ApplicationPackage { @Override public ApplicationFile getFile(Path file) { - return new MockApplicationFile(file, Path.fromString(root.toString())); + throw new UnsupportedOperationException(); } @Override @@ -306,122 +300,4 @@ public class MockApplicationPackage implements ApplicationPackage { return xmlStringWithIdAttribute.substring(idStart + 4, idEnd - 1); } - public static class MockApplicationFile extends ApplicationFile { - - /** The path to the application package root */ - private final Path root; - - /** The File pointing to the actual file represented by this */ - private final File file; - - public MockApplicationFile(Path filePath, Path applicationPackagePath) { - super(filePath); - this.root = applicationPackagePath; - file = applicationPackagePath.append(filePath).toFile(); - } - - @Override - public boolean isDirectory() { - return file.isDirectory(); - } - - @Override - public boolean exists() { - return file.exists(); - } - - @Override - public Reader createReader() throws FileNotFoundException { - try { - if ( ! exists()) throw new FileNotFoundException("File '" + file + "' does not exist"); - return IOUtils.createReader(file, "UTF-8"); - } - catch (IOException e) { - throw new UncheckedIOException(e); - } - } - - @Override - public InputStream createInputStream() throws FileNotFoundException { - try { - if ( ! exists()) throw new FileNotFoundException("File '" + file + "' does not exist"); - return new BufferedInputStream(new FileInputStream(file)); - } - catch (IOException e) { - throw new UncheckedIOException(e); - } - } - - @Override - public ApplicationFile createDirectory() { - file.mkdirs(); - return this; - } - - @Override - public ApplicationFile writeFile(Reader input) { - try { - IOUtils.writeFile(file, IOUtils.readAll(input), false); - return this; - } - catch (IOException e) { - throw new UncheckedIOException(e); - } - } - - @Override - public ApplicationFile appendFile(String value) { - try { - IOUtils.writeFile(file, value, true); - return this; - } - catch (IOException e) { - throw new UncheckedIOException(e); - } - } - - @Override - public List<ApplicationFile> listFiles(PathFilter filter) { - if ( ! isDirectory()) return Collections.emptyList(); - return Arrays.stream(file.listFiles()).filter(f -> filter.accept(Path.fromString(f.toString()))) - .map(f -> new MockApplicationFile(asApplicationRelativePath(f), root)) - .collect(Collectors.toList()); - } - - @Override - public ApplicationFile delete() { - file.delete(); - return this; - } - - @Override - public MetaData getMetaData() { - throw new UnsupportedOperationException(); - } - - @Override - public int compareTo(ApplicationFile other) { - return this.getPath().getName().compareTo((other).getPath().getName()); - } - - /** Strips the application package root path prefix from the path of the given file */ - private Path asApplicationRelativePath(File file) { - Path path = Path.fromString(file.toString()); - - Iterator<String> pathIterator = path.iterator(); - // Skip the path elements this shares with the root - for (Iterator<String> rootIterator = root.iterator(); rootIterator.hasNext(); ) { - String rootElement = rootIterator.next(); - String pathElement = pathIterator.next(); - if ( ! rootElement.equals(pathElement)) throw new RuntimeException("Assumption broken"); - } - // Build a path from the remaining - Path relative = Path.fromString(""); - while (pathIterator.hasNext()) - relative = relative.append(pathIterator.next()); - return relative; - } - - } - } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java index 935b9200868..0911f567fa1 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java @@ -1,6 +1,5 @@ package com.yahoo.searchdefinition.expressiontransforms; -import com.google.common.collect.ImmutableMap; import com.yahoo.collections.Pair; import com.yahoo.config.application.api.ApplicationFile; import com.yahoo.config.application.api.ApplicationPackage; @@ -17,6 +16,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; @@ -38,6 +38,7 @@ import com.yahoo.tensor.serialization.TypedBinaryFormat; import java.io.BufferedReader; import java.io.File; +import java.io.FileNotFoundException; import java.io.IOException; import java.io.Reader; import java.io.StringReader; @@ -64,91 +65,49 @@ import java.util.stream.Collectors; public class ConvertedModel { private final String modelName; - private final String modelDescription; - private final ImmutableMap<String, RankingExpression> expressions; - - /** The source importedModel, or empty if this was created from a stored converted model */ - private final Optional<ImportedModel> sourceModel; - - private ConvertedModel(String modelName, - String modelDescription, - Map<String, RankingExpression> expressions, - Optional<ImportedModel> sourceModel) { - this.modelName = modelName; - this.modelDescription = modelDescription; - this.expressions = ImmutableMap.copyOf(expressions); - this.sourceModel = sourceModel; - } + private final Path modelPath; /** - * Create and store a converted model for a rank profile given from either an imported model, - * or (if unavailable) from stored application package data. + * The ranking expressions of this, indexed by their name. which is a 1-3 part string separated by dots + * where the first part is always the model name, the second the signature or (if none) + * expression name (if more than one), and the third is the output name (if any). */ - public static ConvertedModel fromSourceOrStore(Path modelPath, RankProfileTransformContext context) { - File sourceModel = sourceModelFile(context.rankProfile().applicationPackage(), modelPath); - if (sourceModel.exists()) - return fromSource(toModelName(modelPath), - modelPath.toString(), - context.rankProfile(), - context.queryProfiles(), - context.importedModels().get(sourceModel)); // TODO: Convert to name here, make sure its done just one way - else - return fromStore(toModelName(modelPath), - modelPath.toString(), - context.rankProfile()); - } - - public static ConvertedModel fromSource(String modelName, - String modelDescription, - RankProfile rankProfile, - QueryProfileRegistry queryProfileRegistry, - ImportedModel importedModel) { - ModelStore modelStore = new ModelStore(rankProfile.applicationPackage(), modelName); - return new ConvertedModel(modelName, - modelDescription, - convertAndStore(importedModel, rankProfile, queryProfileRegistry, modelStore), - Optional.of(importedModel)); - } - - public static ConvertedModel fromStore(String modelName, - String modelDescription, - RankProfile rankProfile) { - ModelStore modelStore = new ModelStore(rankProfile.applicationPackage(), modelName); - return new ConvertedModel(modelName, - modelDescription, - convertStored(modelStore, rankProfile), - Optional.empty()); - } + private final Map<String, RankingExpression> expressions; /** - * Returns all the output expressions of this indexed by name. The names consist of one or two parts - * separated by dot, where the first part is the signature name - * if signatures are used, or the expression name if signatures are not used and there are multiple - * expressions, and the second is the output name if signature names are used. + * Create a converted model for a rank profile given from either an imported model, + * or (if unavailable) from stored application package data. */ - public Map<String, RankingExpression> expressions() { return expressions; } + public ConvertedModel(Path modelPath, RankProfileTransformContext context) { + this.modelPath = modelPath; + this.modelName = toModelName(modelPath); + ModelStore store = new ModelStore(context.rankProfile().applicationPackage(), modelPath); + if ( store.hasSourceModel()) + expressions = convertModel(store, context.rankProfile(), context.queryProfiles(), context.importedModels()); + else + expressions = transformFromStoredModel(store, context.rankProfile()); + } - /** - * Returns the expression matching the given arguments. - */ - public ExpressionNode expression(FeatureArguments arguments, RankProfileTransformContext context) { - RankingExpression expression = selectExpression(arguments); - if (sourceModel.isPresent()) // we can verify - verifyRequiredMacros(expression, sourceModel.get(), context.rankProfile(), context.queryProfiles()); - return expression.getRoot(); + private Map<String, RankingExpression> convertModel(ModelStore store, + RankProfile profile, + QueryProfileRegistry queryProfiles, + ImportedModels importedModels) { + ImportedModel model = importedModels.get(store.sourceModelFile()); + return transformFromImportedModel(model, store, profile, queryProfiles); } - private RankingExpression selectExpression(FeatureArguments arguments) { + /** Returns the expression matching the given arguments */ + public ExpressionNode expression(FeatureArguments arguments) { if (expressions.isEmpty()) throw new IllegalArgumentException("No expressions available in " + this); RankingExpression expression = expressions.get(arguments.toName()); - if (expression != null) return expression; + if (expression != null) return expression.getRoot(); if ( ! arguments.signature().isPresent()) { if (expressions.size() > 1) throw new IllegalArgumentException("Multiple candidate expressions " + missingExpressionMessageSuffix()); - return expressions.values().iterator().next(); + return expressions.values().iterator().next().getRoot(); } if ( ! arguments.output().isPresent()) { @@ -160,23 +119,21 @@ public class ConvertedModel { if (entriesWithTheRightPrefix.size() > 1) throw new IllegalArgumentException("Multiple candidate expression named '" + arguments.signature().get() + missingExpressionMessageSuffix()); - return entriesWithTheRightPrefix.get(0).getValue(); + return entriesWithTheRightPrefix.get(0).getValue().getRoot(); } throw new IllegalArgumentException("No expression '" + arguments.toName() + missingExpressionMessageSuffix()); } private String missingExpressionMessageSuffix() { - return "' in model '" + modelDescription + "'. " + + return "' in model '" + this.modelPath + "'. " + "Available expressions: " + expressions.keySet().stream().collect(Collectors.joining(", ")); } - // ----------------------- Static model conversion/storage below here - - private static Map<String, RankingExpression> convertAndStore(ImportedModel model, - RankProfile profile, - QueryProfileRegistry queryProfiles, - ModelStore store) { + private Map<String, RankingExpression> transformFromImportedModel(ImportedModel model, + ModelStore store, + RankProfile profile, + QueryProfileRegistry queryProfiles) { // Add constants Set<String> constantsReplacedByMacros = new HashSet<>(); model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v)); @@ -204,21 +161,22 @@ public class ConvertedModel { return expressions; } - private static void addExpression(RankingExpression expression, - String expressionName, - Set<String> constantsReplacedByMacros, - ImportedModel model, - ModelStore store, - RankProfile profile, - QueryProfileRegistry queryProfiles, - Map<String, RankingExpression> expressions) { + private void addExpression(RankingExpression expression, + String expressionName, + Set<String> constantsReplacedByMacros, + ImportedModel model, + ModelStore store, + RankProfile profile, + QueryProfileRegistry queryProfiles, + Map<String, RankingExpression> expressions) { expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); + verifyRequiredMacros(expression, model, profile, queryProfiles); reduceBatchDimensions(expression, model, profile, queryProfiles); store.writeExpression(expressionName, expression); expressions.put(expressionName, expression); } - private static Map<String, RankingExpression> convertStored(ModelStore store, RankProfile profile) { + private Map<String, RankingExpression> transformFromStoredModel(ModelStore store, RankProfile profile) { for (Pair<String, Tensor> constant : store.readSmallConstants()) profile.addConstant(constant.getFirst(), asValue(constant.getSecond())); @@ -234,12 +192,12 @@ public class ConvertedModel { return store.readExpressions(); } - private static void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) { + private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) { store.writeSmallConstant(constantName, constantValue); profile.addConstant(constantName, asValue(constantValue)); } - private static void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles, + private void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles, Set<String> constantsReplacedByMacros, String constantName, Tensor constantValue) { RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName); @@ -259,7 +217,7 @@ public class ConvertedModel { } } - private static void transformGeneratedMacro(ModelStore store, + private void transformGeneratedMacro(ModelStore store, Set<String> constantsReplacedByMacros, String macroName, RankingExpression expression) { @@ -268,7 +226,7 @@ public class ConvertedModel { store.writeMacro(macroName, expression); } - private static void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) { + private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) { if (profile.getMacros().containsKey(macroName)) { if ( ! profile.getMacros().get(macroName).getRankingExpression().equals(expression)) throw new IllegalArgumentException("Generated macro '" + macroName + "' already exists in " + profile + @@ -285,8 +243,8 @@ public class ConvertedModel { * Verify that the macros referred in the given expression exists in the given rank profile, * and return tensors of the types specified in requiredMacros. */ - private static void verifyRequiredMacros(RankingExpression expression, ImportedModel model, - RankProfile profile, QueryProfileRegistry queryProfiles) { + private void verifyRequiredMacros(RankingExpression expression, ImportedModel model, + RankProfile profile, QueryProfileRegistry queryProfiles) { Set<String> macroNames = new HashSet<>(); addMacroNamesIn(expression.getRoot(), macroNames, model); for (String macroName : macroNames) { @@ -314,7 +272,7 @@ public class ConvertedModel { } } - private static String typeMismatchExplanation(TensorType requiredType, TensorType actualType) { + private String typeMismatchExplanation(TensorType requiredType, TensorType actualType) { return "The required type of this is " + requiredType + ", but this macro returns " + actualType + (actualType.rank() == 0 ? ". This is often due to missing declaration of query tensor features " + "in query profile types - see the documentation." @@ -324,7 +282,7 @@ public class ConvertedModel { /** * Add the generated macros to the rank profile */ - private static void addGeneratedMacros(ImportedModel model, RankProfile profile) { + private void addGeneratedMacros(ImportedModel model, RankProfile profile) { model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v.copy())); } @@ -333,8 +291,8 @@ public class ConvertedModel { * macro specifies that a single exemplar should be evaluated, we can * reduce the batch dimension out. */ - private static void reduceBatchDimensions(RankingExpression expression, ImportedModel model, - RankProfile profile, QueryProfileRegistry queryProfiles) { + private void reduceBatchDimensions(RankingExpression expression, ImportedModel model, + RankProfile profile, QueryProfileRegistry queryProfiles) { TypeContext<Reference> typeContext = profile.typeContext(queryProfiles); TensorType typeBeforeReducing = expression.getRoot().type(typeContext); @@ -361,8 +319,8 @@ public class ConvertedModel { expression.setRoot(root); } - private static ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedModel model, - TypeContext<Reference> typeContext) { + private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedModel model, + TypeContext<Reference> typeContext) { if (node instanceof TensorFunctionNode) { TensorFunction tensorFunction = ((TensorFunctionNode) node).function(); if (tensorFunction instanceof Rename) { @@ -392,7 +350,7 @@ public class ConvertedModel { return node; } - private static ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) { + private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) { TensorFunction result = function; TensorType type = function.type(context); if (type.dimensions().size() > 1) { @@ -414,7 +372,7 @@ public class ConvertedModel { * for any following computation of the tensor. */ // TODO: determine when this is not necessary! - private static ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) { + private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) { if (after.equals(before)) { return node; } @@ -441,14 +399,14 @@ public class ConvertedModel { * If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions. * This method does that for the given expression and returns the result. */ - private static RankingExpression replaceConstantsByMacros(RankingExpression expression, + private RankingExpression replaceConstantsByMacros(RankingExpression expression, Set<String> constantsReplacedByMacros) { if (constantsReplacedByMacros.isEmpty()) return expression; return new RankingExpression(expression.getName(), replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros)); } - private static ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) { + private ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) { if (node instanceof ReferenceNode) { Reference reference = ((ReferenceNode)node).reference(); if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) { @@ -466,7 +424,7 @@ public class ConvertedModel { return node; } - private static void addMacroNamesIn(ExpressionNode node, Set<String> names, ImportedModel model) { + private void addMacroNamesIn(ExpressionNode node, Set<String> names, ImportedModel model) { if (node instanceof ReferenceNode) { ReferenceNode referenceNode = (ReferenceNode)node; if (referenceNode.getOutput() == null) { // macro references cannot specify outputs @@ -482,7 +440,7 @@ public class ConvertedModel { } } - private static Value asValue(Tensor tensor) { + private Value asValue(Tensor tensor) { if (tensor.type().rank() == 0) return new DoubleValue(tensor.asDouble()); // the backend gets offended by dimensionless tensors else @@ -497,13 +455,6 @@ public class ConvertedModel { public String toString() { return "model '" + modelName + "'"; } /** - * Returns the directory which contains the source model to use for these arguments - */ - public static File sourceModelFile(ApplicationPackage application, Path sourceModelPath) { - return application.getFileReference(ApplicationPackage.MODELS_DIR.append(sourceModelPath)); - } - - /** * Provides read/write access to the correct directories of the application package given by the feature arguments */ static class ModelStore { @@ -511,9 +462,20 @@ public class ConvertedModel { private final ApplicationPackage application; private final ModelFiles modelFiles; - ModelStore(ApplicationPackage application, String modelName) { + ModelStore(ApplicationPackage application, Path modelPath) { this.application = application; - this.modelFiles = new ModelFiles(modelName); + this.modelFiles = new ModelFiles(modelPath); + } + + public boolean hasSourceModel() { + return sourceModelFile().exists(); + } + + /** + * Returns the directory which contains the source model to use for these arguments + */ + public File sourceModelFile() { + return application.getFileReference(ApplicationPackage.MODELS_DIR.append(modelFiles.modelPath())); } /** @@ -586,7 +548,7 @@ public class ConvertedModel { List<RankingConstant> readLargeConstants() { try { List<RankingConstant> constants = new ArrayList<>(); - for (ApplicationFile constantFile : application.getFile(modelFiles.largeConstantsInfoPath()).listFiles()) { + for (ApplicationFile constantFile : application.getFile(modelFiles.largeConstantsPath()).listFiles()) { String[] parts = IOUtils.readAll(constantFile.createReader()).split(":"); constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2])); } @@ -604,13 +566,13 @@ public class ConvertedModel { * @return the path to the stored constant, relative to the application package root */ Path writeLargeConstant(String name, Tensor constant) { - Path constantsPath = modelFiles.largeConstantsContentPath(); + Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(modelFiles.modelPath()).append("constants"); // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file: Path constantPath = constantsPath.append(name + ".tbf"); // Remember the constant in a file we replicate in ZooKeeper - application.getFile(modelFiles.largeConstantsInfoPath().append(name + ".constant")) + application.getFile(modelFiles.largeConstantsPath().append(name + ".constant")) .writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath))); // Write content explicitly as a file on the file system as this is distributed using file distribution @@ -647,8 +609,8 @@ public class ConvertedModel { public void writeSmallConstant(String name, Tensor constant) { // Secret file format for remembering constants: application.getFile(modelFiles.smallConstantsPath()).appendFile(name + "\t" + - constant.type().toString() + "\t" + - constant.toString() + "\n"); + constant.type().toString() + "\t" + + constant.toString() + "\n"); } /** Workaround for being constructed with the .preprocessed dir as root while later being used outside it */ @@ -670,24 +632,40 @@ public class ConvertedModel { } } + private void close(Reader reader) { + try { + if (reader != null) + reader.close(); + } + catch (IOException e) { + // ignore + } + } + } static class ModelFiles { - String modelName; + Path modelPath; - public ModelFiles(String modelName) { - this.modelName = modelName; + public ModelFiles(Path modelPath) { + this.modelPath = modelPath; } + /** Returns modelPath with slashes replaced by underscores */ + public String modelName() { return modelPath.toString().replace('/', '_').replace('.', '_'); } + + /** Returns relative path to this model below the "models/" dir in the application package */ + public Path modelPath() { return modelPath; } + /** Files stored below this path will be replicated in zookeeper */ public Path storedModelReplicatedPath() { - return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelName); + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath()); } - /** Files stored below this path will not be replicated in zookeeper */ + /** Files stored below this path will not be replicated */ public Path storedModelPath() { - return ApplicationPackage.MODELS_GENERATED_DIR.append(modelName); + return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath()); } public Path expressionPath(String name) { @@ -703,12 +681,7 @@ public class ConvertedModel { } /** Path to the large (ranking) constants directory */ - public Path largeConstantsContentPath() { - return storedModelPath().append("constants"); - } - - /** Path to the large (ranking) constants directory */ - public Path largeConstantsInfoPath() { + public Path largeConstantsPath() { return storedModelReplicatedPath().append("constants"); } @@ -722,19 +695,27 @@ public class ConvertedModel { /** Encapsulates the arguments of a specific model output */ static class FeatureArguments { + private final String modelName; + private final Path modelPath; + /** Optional arguments */ private final Optional<String> signature, output; public FeatureArguments(Arguments arguments) { - this(optionalArgument(1, arguments), + this(Path.fromString(asString(arguments.expressions().get(0))), + optionalArgument(1, arguments), optionalArgument(2, arguments)); } - public FeatureArguments(Optional<String> signature, Optional<String> output) { + public FeatureArguments(Path modelPath, Optional<String> signature, Optional<String> output) { + this.modelPath = modelPath; + this.modelName = toModelName(modelPath); this.signature = signature; this.output = output; } + public Path modelPath() { return modelPath; } + public Optional<String> signature() { return signature; } public Optional<String> output() { return output; } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java index 229ae0ebaaf..36dc200f3c9 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java @@ -13,7 +13,6 @@ import java.io.File; import java.io.UncheckedIOException; import java.util.HashMap; import java.util.Map; -import java.util.Optional; /** * Replaces instances of the onnx(model-path, output) @@ -42,11 +41,10 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans if ( ! feature.getName().equals("onnx")) return feature; try { - // TODO: Put modelPath in FeatureArguments instead Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0))); ConvertedModel convertedModel = - convertedOnnxModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, context)); - return convertedModel.expression(asFeatureArguments(feature.getArguments()), context); + convertedOnnxModels.computeIfAbsent(modelPath, __ -> new ConvertedModel(modelPath, context)); + return convertedModel.expression(asFeatureArguments(feature.getArguments())); } catch (IllegalArgumentException | UncheckedIOException e) { throw new IllegalArgumentException("Could not use Onnx model from " + feature, e); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index bcb8ef1521d..619c13da764 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -41,8 +41,8 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil try { Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0))); ConvertedModel convertedModel = - convertedTensorFlowModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, context)); - return convertedModel.expression(asFeatureArguments(feature.getArguments()), context); + convertedTensorFlowModels.computeIfAbsent(modelPath, __ -> new ConvertedModel(modelPath, context)); + return convertedModel.expression(asFeatureArguments(feature.getArguments())); } catch (IllegalArgumentException | UncheckedIOException e) { throw new IllegalArgumentException("Could not use tensorflow model from " + feature, e); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java index b4a5069b9d6..e6b08ab0350 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java @@ -43,8 +43,8 @@ public class XgboostFeatureConverter extends ExpressionTransformer<RankProfileTr try { Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0))); ConvertedModel convertedModel = - convertedXGBoostModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, context)); - return convertedModel.expression(asFeatureArguments(feature.getArguments()), context); + convertedXGBoostModels.computeIfAbsent(modelPath, __ -> new ConvertedModel(modelPath, context)); + return convertedModel.expression(asFeatureArguments(feature.getArguments())); } catch (IllegalArgumentException | UncheckedIOException e) { throw new IllegalArgumentException("Could not use XGBoost model from " + feature, e); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java index ec92905cd1f..73dd60f63eb 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java @@ -8,7 +8,6 @@ import com.yahoo.config.ConfigInstance; import com.yahoo.config.ConfigInstance.Builder; import com.yahoo.config.ConfigurationRuntimeException; import com.yahoo.config.FileReference; -import com.yahoo.config.application.api.ApplicationFile; import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.config.application.api.DeployLogger; import com.yahoo.config.application.api.ValidationId; @@ -27,13 +26,11 @@ import com.yahoo.config.model.producer.AbstractConfigProducerRoot; import com.yahoo.config.model.producer.UserConfigRepo; import com.yahoo.config.provision.AllocatedHosts; import com.yahoo.log.LogLevel; -import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.RankingConstants; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RankProfileList; -import com.yahoo.searchdefinition.expressiontransforms.ConvertedModel; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; @@ -165,9 +162,7 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri this.applicationPackage = deployState.getApplicationPackage(); root = builder.getRoot(VespaModel.ROOT_CONFIGID, deployState, this); - createGlobalRankProfiles(deployState.getImportedModels(), - deployState.rankProfileRegistry(), - deployState.getQueryProfiles().getRegistry()); + createGlobalRankProfiles(deployState.getImportedModels(), deployState.rankProfileRegistry()); this.rankProfileList = new RankProfileList(null, // null search -> global AttributeFields.empty, deployState.rankProfileRegistry(), @@ -225,30 +220,14 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri * Creates a rank profile not attached to any search definition, for each imported model in the application package */ private ImmutableList<RankProfile> createGlobalRankProfiles(ImportedModels importedModels, - RankProfileRegistry rankProfileRegistry, - QueryProfileRegistry queryProfiles) { + RankProfileRegistry rankProfileRegistry) { List<RankProfile> profiles = new ArrayList<>(); - if ( ! importedModels.all().isEmpty()) { // models/ directory is available - for (ImportedModel model : importedModels.all()) { - RankProfile profile = new RankProfile(model.name(), this, rankProfileRegistry); - rankProfileRegistry.add(profile); - ConvertedModel convertedModel = ConvertedModel.fromSource(model.name(), model.name(), profile, queryProfiles, model); - for (Map.Entry<String, RankingExpression> entry : convertedModel.expressions().entrySet()) { - profile.addMacro(entry.getKey(), false).setRankingExpression(entry.getValue()); - } - } - } - else { // generated and saved model information may be available instead - ApplicationFile generatedModelsDir = applicationPackage.getFile(ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR); - for (ApplicationFile generatedModelDir : generatedModelsDir.listFiles()) { - String modelName = generatedModelDir.getPath().last(); - RankProfile profile = new RankProfile(modelName, this, rankProfileRegistry); - rankProfileRegistry.add(profile); - ConvertedModel convertedModel = ConvertedModel.fromStore(modelName, modelName, profile); - for (Map.Entry<String, RankingExpression> entry : convertedModel.expressions().entrySet()) { - profile.addMacro(entry.getKey(), false).setRankingExpression(entry.getValue()); - } + for (ImportedModel model : importedModels.all()) { + RankProfile profile = new RankProfile(model.name(), this, rankProfileRegistry); + for (Pair<String, RankingExpression> entry : model.outputExpressions()) { + profile.addMacro(entry.getFirst(), false).setRankingExpression(entry.getSecond()); } + rankProfileRegistry.add(profile); } return ImmutableList.copyOf(profiles); } diff --git a/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java index 8558ccc44bd..d06752c9b6d 100644 --- a/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java +++ b/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java @@ -2,14 +2,10 @@ package com.yahoo.config.model; import ai.vespa.models.evaluation.Model; import ai.vespa.models.evaluation.ModelsEvaluator; -import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.config.model.application.provider.FilesApplicationPackage; -import com.yahoo.io.IOUtils; -import com.yahoo.path.Path; import com.yahoo.vespa.config.search.RankProfilesConfig; import com.yahoo.vespa.model.VespaModel; import com.yahoo.vespa.model.container.ContainerCluster; -import org.junit.After; import org.junit.Test; import org.xml.sax.SAXException; @@ -26,16 +22,11 @@ import static org.junit.Assert.assertTrue; */ public class ModelEvaluationTest { - private static final String appDir = "src/test/cfg/application/ml_serving"; - - @After - public void removeGeneratedModelFiles() { - IOUtils.recursiveDeleteDir(Path.fromString(appDir).append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); - } + private static final String TESTDIR = "src/test/cfg/application/"; @Test public void testMl_ServingApplication() throws SAXException, IOException { - ApplicationPackageTester tester = ApplicationPackageTester.create(appDir); + ApplicationPackageTester tester = ApplicationPackageTester.create(TESTDIR + "ml_serving"); VespaModel model = new VespaModel(tester.app()); ContainerCluster cluster = model.getContainerClusters().get("container"); RankProfilesConfig.Builder b = new RankProfilesConfig.Builder(); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java index 04a6f953bb6..815a01cdb99 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java @@ -142,6 +142,7 @@ public class RankingExpressionWithOnnxTestCase { } } + @Test public void testOnnxReferenceWithWrongMacroType() { try { diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 28fcf871cf3..c317f07b87a 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -403,7 +403,7 @@ public class RankingExpressionWithTensorFlowTestCase { */ private void assertLargeConstant(String name, RankProfileSearchFixture search, Optional<Long> expectedSize) { try { - Path constantApplicationPackagePath = Path.fromString("models.generated/mnist_softmax_saved/constants").append(name + ".tbf"); + Path constantApplicationPackagePath = Path.fromString("models.generated/mnist_softmax/saved/constants").append(name + ".tbf"); RankingConstant rankingConstant = search.search().rankingConstants().get(name); assertEquals(name, rankingConstant.getName()); assertTrue(rankingConstant.getFileName().endsWith(constantApplicationPackagePath.toString())); @@ -485,7 +485,7 @@ public class RankingExpressionWithTensorFlowTestCase { @Override public ApplicationFile getFile(Path file) { - return new MockApplicationFile(file, Path.fromString(root().toString())); + return new StoringApplicationPackageFile(file, Path.fromString(root().toString())); } @Override @@ -505,4 +505,123 @@ public class RankingExpressionWithTensorFlowTestCase { } + static class StoringApplicationPackageFile extends ApplicationFile { + + /** The path to the application package root */ + private final Path root; + + /** The File pointing to the actual file represented by this */ + private final File file; + + StoringApplicationPackageFile(Path filePath, Path applicationPackagePath) { + super(filePath); + this.root = applicationPackagePath; + file = applicationPackagePath.append(filePath).toFile(); + } + + @Override + public boolean isDirectory() { + return file.isDirectory(); + } + + @Override + public boolean exists() { + return file.exists(); + } + + @Override + public Reader createReader() throws FileNotFoundException { + try { + if ( ! exists()) throw new FileNotFoundException("File '" + file + "' does not exist"); + return IOUtils.createReader(file, "UTF-8"); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + @Override + public InputStream createInputStream() throws FileNotFoundException { + try { + if ( ! exists()) throw new FileNotFoundException("File '" + file + "' does not exist"); + return new BufferedInputStream(new FileInputStream(file)); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + @Override + public ApplicationFile createDirectory() { + file.mkdirs(); + return this; + } + + @Override + public ApplicationFile writeFile(Reader input) { + try { + IOUtils.writeFile(file, IOUtils.readAll(input), false); + return this; + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + @Override + public ApplicationFile appendFile(String value) { + try { + IOUtils.writeFile(file, value, true); + return this; + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + @Override + public List<ApplicationFile> listFiles(PathFilter filter) { + if ( ! isDirectory()) return Collections.emptyList(); + return Arrays.stream(file.listFiles()).filter(f -> filter.accept(Path.fromString(f.toString()))) + .map(f -> new StoringApplicationPackageFile(asApplicationRelativePath(f), + root)) + .collect(Collectors.toList()); + } + + @Override + public ApplicationFile delete() { + file.delete(); + return this; + } + + @Override + public MetaData getMetaData() { + throw new UnsupportedOperationException(); + } + + @Override + public int compareTo(ApplicationFile other) { + return this.getPath().getName().compareTo((other).getPath().getName()); + } + + /** Strips the application package root path prefix from the path of the given file */ + private Path asApplicationRelativePath(File file) { + Path path = Path.fromString(file.toString()); + + Iterator<String> pathIterator = path.iterator(); + // Skip the path elements this shares with the root + for (Iterator<String> rootIterator = root.iterator(); rootIterator.hasNext(); ) { + String rootElement = rootIterator.next(); + String pathElement = pathIterator.next(); + if ( ! rootElement.equals(pathElement)) throw new RuntimeException("Assumption broken"); + } + // Build a path from the remaining + Path relative = Path.fromString(""); + while (pathIterator.hasNext()) + relative = relative.append(pathIterator.next()); + return relative; + } + + } + } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java index ca5cdc3cc56..50e83fd1a1e 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java @@ -41,8 +41,8 @@ public class ModelsEvaluator extends AbstractComponent { Model requireModel(String name) { Model model = models.get(name); if (model == null) - throw new IllegalArgumentException("No model named '" + name + "'. Available models: " + - String.join(", ", models.keySet())); + throw new IllegalArgumentException("No model named '" + name + ". Available models: " + + models.keySet().stream().collect(Collectors.joining(", "))); return model; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java index f7fe91cb56f..6716993e1dd 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java @@ -98,33 +98,33 @@ public class ImportedModel { void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); } /** - * Returns all the output expressions of this indexed by name. The names consist of one or two parts - * separated by dot, where the first part is the signature name + * Returns all the outputs of this by name. The names consist of one to three parts + * separated by dot, where the first part is the model name, the second is the signature name * if signatures are used, or the expression name if signatures are not used and there are multiple - * expressions, and the second is the output name if signature names are used. + * expressions, and the third is the output name if signature names are used. */ public List<Pair<String, RankingExpression>> outputExpressions() { - List<Pair<String, RankingExpression>> expressions = new ArrayList<>(); + List<Pair<String, RankingExpression>> names = new ArrayList<>(); for (Map.Entry<String, Signature> signatureEntry : signatures().entrySet()) { for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet()) - expressions.add(new Pair<>(signatureEntry.getKey() + "." + outputEntry.getKey(), - expressions().get(outputEntry.getValue()))); + names.add(new Pair<>(signatureEntry.getKey() + "." + outputEntry.getKey(), + expressions().get(outputEntry.getValue()))); if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs - expressions.add(new Pair<>(signatureEntry.getKey(), - expressions().get(signatureEntry.getKey()))); + names.add(new Pair<>(signatureEntry.getKey(), + expressions().get(signatureEntry.getKey()))); } if (signatures().isEmpty()) { // fallback for models without signatures if (expressions().size() == 1) { - Map.Entry<String, RankingExpression> singleEntry = this.expressions.entrySet().iterator().next(); - expressions.add(new Pair<>(singleEntry.getKey(), singleEntry.getValue())); + Map.Entry<String, RankingExpression> singleEntry = expressions.entrySet().iterator().next(); + names.add(new Pair<>(singleEntry.getKey(), singleEntry.getValue())); } else { for (Map.Entry<String, RankingExpression> expressionEntry : expressions().entrySet()) { - expressions.add(new Pair<>(expressionEntry.getKey(), expressionEntry.getValue())); + names.add(new Pair<>(expressionEntry.getKey(), expressionEntry.getValue())); } } } - return expressions; + return names; } /** diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java index 32d33622a33..b1714b49256 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java @@ -10,9 +10,7 @@ import java.util.Collection; import java.util.Optional; /** - * All models imported from the models/ directory in the application package. - * If this is empty it may be due to either not having any models in the application package, - * or this being created for a ZooKeeper application package, which does not have imported models. + * All models imported from the models/ directory in the application package * * @author bratseth */ @@ -56,22 +54,16 @@ public class ImportedModels { } /** - * Returns the model at the given location in the application package. + * Returns the model at the given location in the application package (lazily loaded), * - * @param modelPath the path to this model (file or directory, depending on model type) - * under the application package, both from the root or relative to the - * models directory works - * @return the model at this path or null if none + * @param modelPath the full path to this model (file or directory, depending on model type) + * under the application package + * @throws IllegalArgumentException if the model cannot be loaded */ public ImportedModel get(File modelPath) { - System.out.println("Name from " + modelPath + ": " + toName(modelPath)); return importedModels.get(toName(modelPath)); } - public ImportedModel get(String modelName) { - return importedModels.get(modelName); - } - /** Returns an immutable collection of all the imported models */ public Collection<ImportedModel> all() { return importedModels.values(); |