diff options
15 files changed, 350 insertions, 210 deletions
diff --git a/config-application-package/src/main/java/com/yahoo/config/model/application/provider/FilesApplicationPackage.java b/config-application-package/src/main/java/com/yahoo/config/model/application/provider/FilesApplicationPackage.java index 36009682022..7ca9bcf48f3 100644 --- a/config-application-package/src/main/java/com/yahoo/config/model/application/provider/FilesApplicationPackage.java +++ b/config-application-package/src/main/java/com/yahoo/config/model/application/provider/FilesApplicationPackage.java @@ -45,6 +45,7 @@ import java.security.MessageDigest; import java.util.*; import java.util.jar.JarFile; import java.util.logging.Logger; +import java.util.stream.Collectors; import static com.yahoo.text.Lowercase.toLowerCase; @@ -164,7 +165,7 @@ public class FilesApplicationPackage implements ApplicationPackage { return metaData; } - private List<NamedReader> getFiles(Path relativePath,String namePrefix,String suffix,boolean recurse) { + private List<NamedReader> getFiles(Path relativePath, String namePrefix, String suffix, boolean recurse) { try { List<NamedReader> readers=new ArrayList<>(); File dir = new File(appDir, relativePath.getRelative()); diff --git a/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java b/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java index dd54fe11c39..a71a0878d3d 100644 --- a/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java +++ b/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java @@ -88,12 +88,14 @@ public interface ApplicationPackage { /** * Contents of services.xml. Caller must close reader after use. + * * @return a Reader, or null if no services.xml/vespa-services.xml present */ Reader getServices(); /** * Contents of hosts.xml. Caller must close reader after use. + * * @return a Reader, or null if no hosts.xml/vespa-hosts.xml present */ Reader getHosts(); @@ -160,8 +162,8 @@ public interface ApplicationPackage { * Gets a file from the root of the application package * * - * @param relativePath The relative path of the file within this application package. - * @return reader for file + * @param relativePath the relative path of the file within this application package. + * @return information abut the file * @throws IllegalArgumentException if the given path does not exist */ ApplicationFile getFile(Path relativePath); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfileRegistry.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfileRegistry.java index 7b4d70d85b1..6311751bb88 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfileRegistry.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfileRegistry.java @@ -52,11 +52,11 @@ public class RankProfileRegistry { } private void checkForDuplicateRankProfile(RankProfile rankProfile) { - final String rankProfileName = rankProfile.getName(); + String rankProfileName = rankProfile.getName(); RankProfile existingRangProfileWithSameName = rankProfiles.get(rankProfile.getSearch()).get(rankProfileName); if (existingRangProfileWithSameName == null) return; - if (!overridableRankProfileNames.contains(rankProfileName)) { + if ( ! overridableRankProfileNames.contains(rankProfileName)) { throw new IllegalArgumentException("Cannot add rank profile '" + rankProfileName + "' in search definition '" + rankProfile.getSearch().getName() + "', since it already exists"); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java index 867740c7912..d85d0983509 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java @@ -1,11 +1,11 @@ package com.yahoo.searchdefinition.expressiontransforms; -import com.google.common.base.Joiner; import com.yahoo.collections.Pair; import com.yahoo.config.application.api.ApplicationFile; import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.config.model.application.provider.FilesApplicationPackage; import com.yahoo.io.IOUtils; +import com.yahoo.io.reader.NamedReader; import com.yahoo.path.Path; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.FeatureNames; @@ -39,11 +39,14 @@ import com.yahoo.tensor.serialization.TypedBinaryFormat; import java.io.BufferedReader; import java.io.File; +import java.io.FileNotFoundException; import java.io.IOException; +import java.io.Reader; import java.io.StringReader; import java.io.UncheckedIOException; import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -62,69 +65,155 @@ import java.util.stream.Collectors; */ public class ConvertedModel { - private final ExpressionNode convertedExpression; + private final String modelName; + private final Path modelPath; - public ConvertedModel(FeatureArguments arguments, + /** + * The ranking expressions of this, indexed by their name. which is a 1-3 part string separated by dots + * where the first part is always the model name, the second the signature or (if none) + * expression name (if more than one), and the third is the output name (if any). + */ + private final Map<String, RankingExpression> expressions; + + public ConvertedModel(Path modelPath, RankProfileTransformContext context, ModelImporter modelImporter, - Map<Path, ImportedModel> importedModels) { - ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), arguments); + FeatureArguments arguments) { // TODO: Remove + this.modelPath = modelPath; + this.modelName = toModelName(modelPath); + ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), modelPath); if ( ! store.hasStoredModel()) // not converted yet - access from models/ directory - convertedExpression = importModel(store, context.rankProfile(), context.queryProfiles(), modelImporter, importedModels); + expressions = importModel(store, context.rankProfile(), context.queryProfiles(), modelImporter, arguments); else - convertedExpression = transformFromStoredModel(store, context.rankProfile()); + expressions = transformFromStoredModel(store, context.rankProfile()); } - private ExpressionNode importModel(ModelStore store, - RankProfile profile, - QueryProfileRegistry queryProfiles, - ModelImporter modelImporter, - Map<Path, ImportedModel> importedModels) { - ImportedModel model = importedModels.computeIfAbsent(store.arguments().modelPath(), - k -> modelImporter.importModel(store.arguments().modelName(), - store.modelDir())); - return transformFromImportedModel(model, store, profile, queryProfiles); + private Map<String, RankingExpression> importModel(ModelStore store, + RankProfile profile, + QueryProfileRegistry queryProfiles, + ModelImporter modelImporter, + FeatureArguments arguments) { + ImportedModel model = modelImporter.importModel(store.modelFiles.modelName(), store.modelDir()); + return transformFromImportedModel(model, store, profile, queryProfiles, arguments); + } + + /** Returns the expression matching the given arguments */ + public ExpressionNode expression(FeatureArguments arguments) { + if (expressions.isEmpty()) + throw new IllegalArgumentException("No expressions available in " + this); + + RankingExpression expression = expressions.get(arguments.toName()); + if (expression != null) return expression.getRoot(); + + if ( ! arguments.signature().isPresent()) { + if (expressions.size() > 1) + throw new IllegalArgumentException("Multiple candidate expressions " + missingExpressionMessageSuffix()); + return expressions.values().iterator().next().getRoot(); + } + + if ( ! arguments.output().isPresent()) { + List<Map.Entry<String, RankingExpression>> entriesWithTheRightPrefix = + expressions.entrySet().stream().filter(entry -> entry.getKey().startsWith(modelName + "." + arguments.signature().get() + ".")).collect(Collectors.toList()); + if (entriesWithTheRightPrefix.size() < 1) + throw new IllegalArgumentException("No expressions named '" + arguments.signature().get() + + missingExpressionMessageSuffix()); + if (entriesWithTheRightPrefix.size() > 1) + throw new IllegalArgumentException("Multiple candidate expression named '" + arguments.signature().get() + + missingExpressionMessageSuffix()); + return entriesWithTheRightPrefix.get(0).getValue().getRoot(); + } + + throw new IllegalArgumentException("No expression '" + arguments.toName() + missingExpressionMessageSuffix()); } - public ExpressionNode expression() { return convertedExpression; } + private String missingExpressionMessageSuffix() { + return "' in model '" + this.modelPath + "'. " + + "Available expressions: " + expressions.keySet().stream().collect(Collectors.joining(", ")); + } - private ExpressionNode transformFromImportedModel(ImportedModel model, - ModelStore store, - RankProfile profile, - QueryProfileRegistry queryProfiles) { + private Map<String, RankingExpression> transformFromImportedModel(ImportedModel model, + ModelStore store, + RankProfile profile, + QueryProfileRegistry queryProfiles, + FeatureArguments arguments) { // Add constants Set<String> constantsReplacedByMacros = new HashSet<>(); model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v)); model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles, constantsReplacedByMacros, k, v)); - // Find the specified expression - ImportedModel.Signature signature = chooseSignature(model, store.arguments().signature()); - String output = chooseOutput(signature, store.arguments().output()); - if (signature.skippedOutputs().containsKey(output)) { - String message = "Could not import model output '" + output + "'"; - if (!signature.skippedOutputs().get(output).isEmpty()) { - message += ": " + signature.skippedOutputs().get(output); + // Add macros + addGeneratedMacros(model, profile); + + // Add expressions + Map<String, RankingExpression> expressions = new HashMap<>(); + for (Map.Entry<String, ImportedModel.Signature> signatureEntry : model.signatures().entrySet()) { + if ( ! matches(signatureEntry.getValue(), arguments, Optional.empty())) continue; + + for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet()) { + if ( ! matches(signatureEntry.getValue(), arguments, Optional.of(outputEntry.getKey()))) continue; + addExpression(model.expressions().get(outputEntry.getValue()), + modelName + "." + signatureEntry.getKey() + "." + outputEntry.getKey(), + constantsReplacedByMacros, + model, store, profile, queryProfiles, + expressions); } - if (!signature.importWarnings().isEmpty()) { - message += ": " + String.join(", ", signature.importWarnings()); + if (signatureEntry.getValue().outputs().isEmpty()) { // fallback: Signature without outputs + addExpression(model.expressions().get(signatureEntry.getKey()), + modelName + "." + signatureEntry.getKey(), + constantsReplacedByMacros, + model, store, profile, queryProfiles, + expressions); } - throw new IllegalArgumentException(message); } + if (model.signatures().isEmpty()) { // fallback: Model without signatures + if (model.expressions().size() == 1) { // Use just model name + addExpression(model.expressions().values().iterator().next(), + modelName, + constantsReplacedByMacros, + model, store, profile, queryProfiles, + expressions); + } + else { + for (Map.Entry<String, RankingExpression> expressionEntry : model.expressions().entrySet()) { + addExpression(expressionEntry.getValue(), + modelName + "." + expressionEntry.getKey(), + constantsReplacedByMacros, + model, store, profile, queryProfiles, + expressions); + } + } + } + + // Transform and save macro - must come after reading expressions due to optimization transforms + model.macros().forEach((k, v) -> transformGeneratedMacro(store, constantsReplacedByMacros, k, v)); - RankingExpression expression = model.expressions().get(output); + return expressions; + } + + private boolean matches(ImportedModel.Signature signature, FeatureArguments arguments, Optional<String> output) { + if ( ! modelName.equals(arguments.modelName)) return false; + if ( arguments.signature.isPresent() && ! signature.name().equals(arguments.signature().get())) return false; + if (output.isPresent() && arguments.output().isPresent() && ! output.get().matches(arguments.output().get())) return false; + return true; + } + + private void addExpression(RankingExpression expression, + String expressionName, + Set<String> constantsReplacedByMacros, + ImportedModel model, + ModelStore store, + RankProfile profile, + QueryProfileRegistry queryProfiles, + Map<String, RankingExpression> expressions) { expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); verifyRequiredMacros(expression, model, profile, queryProfiles); - addGeneratedMacros(model, profile); reduceBatchDimensions(expression, model, profile, queryProfiles); - - model.macros().forEach((k, v) -> transformGeneratedMacro(store, constantsReplacedByMacros, k, v)); - - store.writeConverted(expression); - return expression.getRoot(); + store.writeExpression(expressionName, expression); + expressions.put(expressionName, expression); } - ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) { + private Map<String, RankingExpression> transformFromStoredModel(ModelStore store, RankProfile profile) { for (Pair<String, Tensor> constant : store.readSmallConstants()) profile.addConstant(constant.getFirst(), asValue(constant.getSecond())); @@ -137,60 +226,7 @@ public class ConvertedModel { addGeneratedMacroToProfile(profile, macro.getFirst(), macro.getSecond()); } - return store.readConverted().getRoot(); - } - - /** - * Returns the specified, existing signature, or the only signature if none is specified. - * Throws IllegalArgumentException in all other cases. - */ - private ImportedModel.Signature chooseSignature(ImportedModel importResult, Optional<String> signatureName) { - if ( ! signatureName.isPresent()) { - if (importResult.signatures().size() == 0) - throw new IllegalArgumentException("No signatures are available"); - if (importResult.signatures().size() > 1) - throw new IllegalArgumentException("Model has multiple signatures (" + - Joiner.on(", ").join(importResult.signatures().keySet()) + - "), one must be specified " + - "as a second argument to tensorflow()"); - return importResult.signatures().values().stream().findFirst().get(); - } - else { - ImportedModel.Signature signature = importResult.signatures().get(signatureName.get()); - if (signature == null) - throw new IllegalArgumentException("Model does not have the specified signature '" + - signatureName.get() + "'"); - return signature; - } - } - - /** - * Returns the specified, existing output expression, or the only output expression if no output name is specified. - * Throws IllegalArgumentException in all other cases. - */ - private String chooseOutput(ImportedModel.Signature signature, Optional<String> outputName) { - if ( ! outputName.isPresent()) { - if (signature.outputs().size() == 0) - throw new IllegalArgumentException("No outputs are available" + skippedOutputsDescription(signature)); - if (signature.outputs().size() > 1) - throw new IllegalArgumentException(signature + " has multiple outputs (" + - Joiner.on(", ").join(signature.outputs().keySet()) + - "), one must be specified " + - "as a third argument to tensorflow()"); - return signature.outputs().get(signature.outputs().keySet().stream().findFirst().get()); - } - else { - String output = signature.outputs().get(outputName.get()); - if (output == null) { - if (signature.skippedOutputs().containsKey(outputName.get())) - throw new IllegalArgumentException("Could not use output '" + outputName.get() + "': " + - signature.skippedOutputs().get(outputName.get())); - else - throw new IllegalArgumentException("Model does not have the specified output '" + - outputName.get() + "'"); - } - return output; - } + return store.readExpressions(); } private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) { @@ -227,22 +263,15 @@ public class ConvertedModel { } private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) { - if (profile.getMacros().containsKey(macroName)) { + if (profile.getMacros().containsKey(macroName)) throw new IllegalArgumentException("Generated macro '" + macroName + "' already exists."); - } + profile.addMacro(macroName, false); // todo: inline if only used once RankProfile.Macro macro = profile.getMacros().get(macroName); macro.setRankingExpression(expression); macro.setTextualExpression(expression.getRoot().toString()); } - private String skippedOutputsDescription(ImportedModel.Signature signature) { - if (signature.skippedOutputs().isEmpty()) return ""; - StringBuilder b = new StringBuilder(": "); - signature.skippedOutputs().forEach((k, v) -> b.append("Skipping output '").append(k).append("': ").append(v)); - return b.toString(); - } - /** * Verify that the macros referred in the given expression exists in the given rank profile, * and return tensors of the types specified in requiredMacros. @@ -375,8 +404,8 @@ public class ConvertedModel { /** * If batch dimensions have been reduced away above, bring them back here * for any following computation of the tensor. - * Todo: determine when this is not necessary! */ + // TODO: determine when this is not necessary! private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) { if (after.equals(before)) { return node; @@ -452,24 +481,29 @@ public class ConvertedModel { return new TensorValue(tensor); } + private static String toModelName(Path modelPath) { + return modelPath.toString().replace("/", "_"); + } + + @Override + public String toString() { return "model '" + modelName + "'"; } + /** * Provides read/write access to the correct directories of the application package given by the feature arguments */ static class ModelStore { private final ApplicationPackage application; - private final FeatureArguments arguments; + private final ModelFiles modelFiles; - ModelStore(ApplicationPackage application, FeatureArguments arguments) { + ModelStore(ApplicationPackage application, Path modelPath) { this.application = application; - this.arguments = arguments; + this.modelFiles = new ModelFiles(modelPath); } - public FeatureArguments arguments() { return arguments; } - public boolean hasStoredModel() { try { - return application.getFile(arguments.expressionPath()).exists(); + return application.getFileReference(modelFiles.storedModelPath()).exists(); } catch (UnsupportedOperationException e) { return false; @@ -480,40 +514,49 @@ public class ConvertedModel { * Returns the directory which contains the source model to use for these arguments */ public File modelDir() { - return application.getFileReference(ApplicationPackage.MODELS_DIR.append(arguments.modelPath())); + return application.getFileReference(ApplicationPackage.MODELS_DIR.append(modelFiles.modelPath())); } /** * Adds this expression to the application package, such that it can be read later. + * + * @param name the name of this ranking expression - may have 1-3 parts separated by dot where the first part + * is always the model name */ - void writeConverted(RankingExpression expression) { - application.getFile(arguments.expressionPath()) + void writeExpression(String name, RankingExpression expression) { + application.getFile(modelFiles.expressionPath(name)) .writeFile(new StringReader(expression.getRoot().toString())); } - /** Reads the previously stored ranking expression for these arguments */ - RankingExpression readConverted() { - try { - return new RankingExpression(application.getFile(arguments.expressionPath()).createReader()); - } - catch (IOException e) { - throw new UncheckedIOException("Could not read " + arguments.expressionPath(), e); - } - catch (ParseException e) { - throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e); + Map<String, RankingExpression> readExpressions() { + Map<String, RankingExpression> expressions = new HashMap<>(); + ApplicationFile expressionPath = application.getFile(modelFiles.expressionsPath()); + if ( ! expressionPath.exists() || ! expressionPath.isDirectory()) return Collections.emptyMap(); + for (ApplicationFile expressionFile : expressionPath.listFiles()) { + try { + String name = expressionFile.getPath().getName(); + expressions.put(name, new RankingExpression(name, expressionFile.createReader())); + } + catch (FileNotFoundException e) { + throw new IllegalStateException("Expression file removed while reading: " + expressionFile, e); + } + catch (ParseException e) { + throw new IllegalStateException("Invalid stored expression in " + expressionFile, e); + } } + return expressions; } /** Adds this macro expression to the application package to it can be read later. */ void writeMacro(String name, RankingExpression expression) { - application.getFile(arguments.macrosPath()).appendFile(name + "\t" + + application.getFile(modelFiles.macrosPath()).appendFile(name + "\t" + expression.getRoot().toString() + "\n"); } /** Reads the previously stored macro expressions for these arguments */ List<Pair<String, RankingExpression>> readMacros() { try { - ApplicationFile file = application.getFile(arguments.macrosPath()); + ApplicationFile file = application.getFile(modelFiles.macrosPath()); if (!file.exists()) return Collections.emptyList(); List<Pair<String, RankingExpression>> macros = new ArrayList<>(); @@ -527,7 +570,7 @@ public class ConvertedModel { macros.add(new Pair<>(name, expression)); } catch (ParseException e) { - throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e); + throw new IllegalStateException("Could not parse " + name, e); } } return macros; @@ -544,7 +587,7 @@ public class ConvertedModel { List<RankingConstant> readLargeConstants() { try { List<RankingConstant> constants = new ArrayList<>(); - for (ApplicationFile constantFile : application.getFile(arguments.largeConstantsPath()).listFiles()) { + for (ApplicationFile constantFile : application.getFile(modelFiles.largeConstantsPath()).listFiles()) { String[] parts = IOUtils.readAll(constantFile.createReader()).split(":"); constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2])); } @@ -562,13 +605,13 @@ public class ConvertedModel { * @return the path to the stored constant, relative to the application package root */ Path writeLargeConstant(String name, Tensor constant) { - Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants"); + Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(modelFiles.modelPath()).append("constants"); // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file: Path constantPath = constantsPath.append(name + ".tbf"); // Remember the constant in a file we replicate in ZooKeeper - application.getFile(arguments.largeConstantsPath().append(name + ".constant")) + application.getFile(modelFiles.largeConstantsPath().append(name + ".constant")) .writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath))); // Write content explicitly as a file on the file system as this is distributed using file distribution @@ -579,7 +622,7 @@ public class ConvertedModel { private List<Pair<String, Tensor>> readSmallConstants() { try { - ApplicationFile file = application.getFile(arguments.smallConstantsPath()); + ApplicationFile file = application.getFile(modelFiles.smallConstantsPath()); if (!file.exists()) return Collections.emptyList(); List<Pair<String, Tensor>> constants = new ArrayList<>(); @@ -604,7 +647,7 @@ public class ConvertedModel { */ public void writeSmallConstant(String name, Tensor constant) { // Secret file format for remembering constants: - application.getFile(arguments.smallConstantsPath()).appendFile(name + "\t" + + application.getFile(modelFiles.smallConstantsPath()).appendFile(name + "\t" + constant.type().toString() + "\t" + constant.toString() + "\n"); } @@ -628,26 +671,24 @@ public class ConvertedModel { } } + private void close(Reader reader) { + try { + if (reader != null) + reader.close(); + } + catch (IOException e) { + // ignore + } + } + } - /** Encapsulates the arguments to the import feature */ - static class FeatureArguments { + static class ModelFiles { Path modelPath; - /** Optional arguments */ - Optional<String> signature, output; - - public FeatureArguments(Arguments arguments) { - this(Path.fromString(asString(arguments.expressions().get(0))), - optionalArgument(1, arguments), - optionalArgument(2, arguments)); - } - - public FeatureArguments(Path modelPath, Optional<String> signature, Optional<String> output) { + public ModelFiles(Path modelPath) { this.modelPath = modelPath; - this.signature = signature; - this.output = output; } /** Returns modelPath with slashes replaced by underscores */ @@ -655,37 +696,66 @@ public class ConvertedModel { /** Returns relative path to this model below the "models/" dir in the application package */ public Path modelPath() { return modelPath; } - public Optional<String> signature() { return signature; } - public Optional<String> output() { return output; } - /** Path to the small constants file */ + public Path storedModelPath() { + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath()); + } + + public Path expressionPath(String name) { + return storedModelPath().append("expressions").append(name); + } + + public Path expressionsPath() { + return storedModelPath().append("expressions"); + } + public Path smallConstantsPath() { - return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath).append("constants.txt"); + return storedModelPath().append("constants.txt"); } /** Path to the large (ranking) constants directory */ public Path largeConstantsPath() { - return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants"); + return storedModelPath().append("constants"); } /** Path to the macros file */ public Path macrosPath() { - return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("macros.txt"); + return storedModelPath().append("macros.txt"); } - public Path expressionPath() { - return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR - .append(modelPath).append("expressions").append(expressionFileName()); + } + + /** Encapsulates the arguments of a specific model output */ + static class FeatureArguments { + + private final String modelName; + private final Path modelPath; + + /** Optional arguments */ + private final Optional<String> signature, output; + + public FeatureArguments(Arguments arguments) { + this(Path.fromString(asString(arguments.expressions().get(0))), + optionalArgument(1, arguments), + optionalArgument(2, arguments)); } - private String expressionFileName() { - StringBuilder fileName = new StringBuilder(); - signature.ifPresent(s -> fileName.append(s).append(".")); - output.ifPresent(s -> fileName.append(s).append(".")); - if (fileName.length() == 0) // single signature and output - fileName.append("single."); - fileName.append("expression"); - return fileName.toString(); + public FeatureArguments(Path modelPath, Optional<String> signature, Optional<String> output) { + this.modelPath = modelPath; + this.modelName = toModelName(modelPath); + this.signature = signature; + this.output = output; + } + + public Path modelPath() { return modelPath; } + + public Optional<String> signature() { return signature; } + public Optional<String> output() { return output; } + + public String toName() { + return modelName + + (signature.isPresent() ? "." + signature.get() : "") + + (output.isPresent() ? "." + output.get() : ""); } private static Optional<String> optionalArgument(int argumentIndex, Arguments arguments) { @@ -694,7 +764,7 @@ public class ConvertedModel { return Optional.of(asString(arguments.expressions().get(argumentIndex))); } - private static String asString(ExpressionNode node) { + public static String asString(ExpressionNode node) { if ( ! (node instanceof ConstantNode)) throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node); return stripQuotes(((ConstantNode)node).sourceString()); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java index d31ffefde65..97395c1aad3 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java @@ -31,7 +31,7 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans private final OnnxImporter onnxImporter = new OnnxImporter(); /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */ - private final Map<Path, ImportedModel> importedModels = new HashMap<>(); + private final Map<Path, ConvertedModel> convertedModels = new HashMap<>(); @Override public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) { @@ -47,9 +47,9 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans if ( ! feature.getName().equals("onnx")) return feature; try { - ConvertedModel convertedModel = new ConvertedModel(asFeatureArguments(feature.getArguments()), - context, onnxImporter, importedModels); - return convertedModel.expression(); + Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0))); + ConvertedModel convertedModel = convertedModels.computeIfAbsent(modelPath, path -> new ConvertedModel(path, context, onnxImporter, new ConvertedModel.FeatureArguments(feature.getArguments()))); + return convertedModel.expression(asFeatureArguments(feature.getArguments())); } catch (IllegalArgumentException | UncheckedIOException e) { throw new IllegalArgumentException("Could not use Onnx model from " + feature, e); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index d28299b1d30..b3778e2af84 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -28,7 +28,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil private final TensorFlowImporter tensorFlowImporter = new TensorFlowImporter(); /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */ - private final Map<Path, ImportedModel> importedModels = new HashMap<>(); + private final Map<Path, ConvertedModel> convertedModels = new HashMap<>(); @Override public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) { @@ -44,9 +44,9 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil if ( ! feature.getName().equals("tensorflow")) return feature; try { - ConvertedModel convertedModel = new ConvertedModel(asFeatureArguments(feature.getArguments()), - context, tensorFlowImporter, importedModels); - return convertedModel.expression(); + Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0))); + ConvertedModel convertedModel = convertedModels.computeIfAbsent(modelPath, path -> new ConvertedModel(path, context, tensorFlowImporter, new ConvertedModel.FeatureArguments(feature.getArguments()))); + return convertedModel.expression(asFeatureArguments(feature.getArguments())); } catch (IllegalArgumentException | UncheckedIOException e) { throw new IllegalArgumentException("Could not use tensorflow model from " + feature, e); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java index 4ae223ec3a5..62f43e15849 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java @@ -37,7 +37,8 @@ public class XgboostFeatureConverter extends ExpressionTransformer<RankProfileTr try { ConvertedModel.FeatureArguments arguments = asFeatureArguments(feature.getArguments()); - ConvertedModel.ModelStore store = new ConvertedModel.ModelStore(context.rankProfile().getSearch().sourceApplication(), arguments); + ConvertedModel.ModelStore store = new ConvertedModel.ModelStore(context.rankProfile().getSearch().sourceApplication(), + arguments.modelPath()); RankingExpression expression = xgboostImporter.parseModel(store.modelDir().toString()); return expression.getRoot(); } catch (IllegalArgumentException | UncheckedIOException e) { @@ -48,7 +49,7 @@ public class XgboostFeatureConverter extends ExpressionTransformer<RankProfileTr private ConvertedModel.FeatureArguments asFeatureArguments(Arguments arguments) { if (arguments.isEmpty()) throw new IllegalArgumentException("An xgboost node must take an argument pointing to " + - "the xgboost model directory under [application]/models"); + "the xgboost model directory under [application]/models"); if (arguments.expressions().size() > 1) throw new IllegalArgumentException("An xgboost feature can have at most 1 argument"); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java index 0ce6129ef7f..ab689b88993 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java @@ -10,7 +10,9 @@ import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.parser.ParseException; +import java.util.HashMap; import java.util.List; +import java.util.Map; import static org.junit.Assert.assertEquals; @@ -25,6 +27,7 @@ class RankProfileSearchFixture { private RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); private final QueryProfileRegistry queryProfileRegistry; private Search search; + private Map<String, RankProfile> compiledRankProfiles = new HashMap<>(); RankProfileSearchFixture(String rankProfiles) throws ParseException { this(MockApplicationPackage.createEmpty(), new QueryProfileRegistry(), rankProfiles); @@ -54,25 +57,38 @@ class RankProfileSearchFixture { } public void assertFirstPhaseExpression(String expExpression, String rankProfile) { - assertEquals(expExpression, rankProfile(rankProfile).getFirstPhaseRanking().getRoot().toString()); + assertEquals(expExpression, compiledRankProfile(rankProfile).getFirstPhaseRanking().getRoot().toString()); } public void assertSecondPhaseExpression(String expExpression, String rankProfile) { - assertEquals(expExpression, rankProfile(rankProfile).getSecondPhaseRanking().getRoot().toString()); + assertEquals(expExpression, compiledRankProfile(rankProfile).getSecondPhaseRanking().getRoot().toString()); } public void assertRankProperty(String expValue, String name, String rankProfile) { - List<RankProfile.RankProperty> rankPropertyList = rankProfile(rankProfile).getRankPropertyMap().get(name); + List<RankProfile.RankProperty> rankPropertyList = compiledRankProfile(rankProfile).getRankPropertyMap().get(name); assertEquals(1, rankPropertyList.size()); assertEquals(expValue, rankPropertyList.get(0).getValue()); } - public void assertMacro(String expExpression, String macroName, String rankProfile) { - assertEquals(expExpression, rankProfile(rankProfile).getMacros().get(macroName).getRankingExpression().getRoot().toString()); + public void assertMacro(String expexctedExpression, String macroName, String rankProfile) { + assertEquals(expexctedExpression, + compiledRankProfile(rankProfile).getMacros().get(macroName).getRankingExpression().getRoot().toString()); } + public RankProfile compileRankProfile(String rankProfile) { + RankProfile compiled = rankProfileRegistry.getRankProfile(search, rankProfile).compile(queryProfileRegistry); + compiledRankProfiles.put(rankProfile, compiled); + return compiled; + } + + /** Returns the given uncompiled profile */ public RankProfile rankProfile(String rankProfile) { - return rankProfileRegistry.getRankProfile(search, rankProfile).compile(queryProfileRegistry); + return rankProfileRegistry.getRankProfile(search, rankProfile); + } + + /** Returns the given compiled profile, or null if not compiled yet or not present at all */ + public RankProfile compiledRankProfile(String rankProfile) { + return compiledRankProfiles.get(rankProfile); } public Search search() { return search; } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java index b2ef08dcc36..a7465fa9695 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java @@ -123,6 +123,7 @@ public class RankingExpressionWithOnnxTestCase { " expression: onnx('mnist_softmax.onnx')" + " }\n" + " }"); + search.compileRankProfile("my_profile"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); fail("Expecting exception"); } @@ -164,7 +165,8 @@ public class RankingExpressionWithOnnxTestCase { catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + "onnx('mnist_softmax.onnx','y'): " + - "Model does not have the specified signature 'y'", + "No expressions available in model 'mnist_softmax.onnx'", +// "No expressions named 'y' in model 'mnist_softmax.onnx'. Available expressions: mnist_softmax.onnx.default.add", Exceptions.toMessageString(expected)); } } @@ -220,7 +222,8 @@ public class RankingExpressionWithOnnxTestCase { String vespaExpressionWithoutConstant = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), mnist_softmax_onnx_Variable, f(a,b)(a * b)), sum, d2), constant(mnist_softmax_onnx_Variable_1), f(a,b)(a + b))"; - RankProfileSearchFixture search = fixtureWith(rankProfile, new StoringApplicationPackage(applicationDir)); + RankProfileSearchFixture search = uncompiledFixtureWith(rankProfile, new StoringApplicationPackage(applicationDir)); + search.compileRankProfile("my_profile"); search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); assertNull("Constant overridden by macro is not added", @@ -234,7 +237,8 @@ public class RankingExpressionWithOnnxTestCase { IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); - RankProfileSearchFixture searchFromStored = fixtureWith(rankProfile, storedApplication); + RankProfileSearchFixture searchFromStored = uncompiledFixtureWith(rankProfile, storedApplication); + searchFromStored.compileRankProfile("my_profile"); searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); assertNull("Constant overridden by macro is not added", searchFromStored.search().getRankingConstants().get("mnist_softmax_onnx_Variable")); @@ -271,19 +275,19 @@ public class RankingExpressionWithOnnxTestCase { private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression) { return fixtureWith(placeholderExpression, firstPhaseExpression, null, null, "Placeholder", - new StoringApplicationPackage(applicationDir)); + new StoringApplicationPackage(applicationDir)); } private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression, String constant, String field) { return fixtureWith(placeholderExpression, firstPhaseExpression, constant, field, "Placeholder", - new StoringApplicationPackage(applicationDir)); + new StoringApplicationPackage(applicationDir)); } - private RankProfileSearchFixture fixtureWith(String rankProfile, StoringApplicationPackage application) { + private RankProfileSearchFixture uncompiledFixtureWith(String rankProfile, StoringApplicationPackage application) { try { return new RankProfileSearchFixture(application, application.getQueryProfiles(), - rankProfile, null, null); + rankProfile, null, null); } catch (ParseException e) { throw new IllegalArgumentException(e); @@ -297,7 +301,7 @@ public class RankingExpressionWithOnnxTestCase { String macroName, StoringApplicationPackage application) { try { - return new RankProfileSearchFixture( + RankProfileSearchFixture fixture = new RankProfileSearchFixture( application, application.getQueryProfiles(), " rank-profile my_profile {\n" + @@ -310,6 +314,8 @@ public class RankingExpressionWithOnnxTestCase { " }", constant, field); + fixture.compileRankProfile("my_profile"); + return fixture; } catch (ParseException e) { throw new IllegalArgumentException(e); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 7228af2b0de..29859817736 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -6,6 +6,7 @@ import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.config.model.test.MockApplicationPackage; import com.yahoo.io.GrowableByteBuffer; import com.yahoo.io.IOUtils; +import com.yahoo.io.reader.NamedReader; import com.yahoo.path.Path; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankingConstant; @@ -22,10 +23,12 @@ import java.io.BufferedInputStream; import java.io.File; import java.io.FileInputStream; import java.io.FileNotFoundException; +import java.io.FileReader; import java.io.IOException; import java.io.InputStream; import java.io.Reader; import java.io.UncheckedIOException; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.Iterator; @@ -156,6 +159,7 @@ public class RankingExpressionWithTensorFlowTestCase { " expression: tensorflow('mnist_softmax/saved')" + " }\n" + " }"); + search.compileRankProfile("my_profile"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); fail("Expecting exception"); } @@ -196,7 +200,9 @@ public class RankingExpressionWithTensorFlowTestCase { catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " + "tensorflow('mnist_softmax/saved','serving_defaultz'): " + - "Model does not have the specified signature 'serving_defaultz'", + "No expressions available in model 'mnist_softmax_saved'", +// "No expressions named 'serving_defaultz' in model 'mnist_softmax/saved'. "+ +// "Available expressions: mnist_softmax_saved.serving_default.y", Exceptions.toMessageString(expected)); } } @@ -212,7 +218,9 @@ public class RankingExpressionWithTensorFlowTestCase { catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " + "tensorflow('mnist_softmax/saved','serving_default','x'): " + - "Model does not have the specified output 'x'", + "No expressions available in model 'mnist_softmax_saved'", +// "No expression 'mnist_softmax_saved.serving_default.x' in model 'mnist_softmax/saved'. " + +// "Available expressions: mnist_softmax_saved.serving_default.y", Exceptions.toMessageString(expected)); } } @@ -268,7 +276,8 @@ public class RankingExpressionWithTensorFlowTestCase { String vespaExpressionWithoutConstant = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), mnist_softmax_saved_layer_Variable_read, f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))"; - RankProfileSearchFixture search = fixtureWith(rankProfile, new StoringApplicationPackage(applicationDir)); + RankProfileSearchFixture search = fixtureWithUncompiled(rankProfile, new StoringApplicationPackage(applicationDir)); + search.compileRankProfile("my_profile"); search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); assertNull("Constant overridden by macro is not added", @@ -282,7 +291,8 @@ public class RankingExpressionWithTensorFlowTestCase { IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); - RankProfileSearchFixture searchFromStored = fixtureWith(rankProfile, storedApplication); + RankProfileSearchFixture searchFromStored = fixtureWithUncompiled(rankProfile, storedApplication); + searchFromStored.compileRankProfile("my_profile"); searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); assertNull("Constant overridden by macro is not added", searchFromStored.search().getRankingConstants().get("mnist_softmax_saved_layer_Variable_read")); @@ -297,7 +307,7 @@ public class RankingExpressionWithTensorFlowTestCase { public void testTensorFlowReduceBatchDimension() { final String expression = "join(join(reduce(join(reduce(rename(Placeholder, (d0, d1), (d0, d2)), sum, d0), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))"; RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)", - "tensorflow('mnist_softmax/saved')"); + "tensorflow('mnist_softmax/saved')"); search.assertFirstPhaseExpression(expression, "my_profile"); assertLargeConstant("mnist_softmax_saved_layer_Variable_1_read", search, Optional.of(10L)); assertLargeConstant("mnist_softmax_saved_layer_Variable_read", search, Optional.of(7840L)); @@ -362,7 +372,7 @@ public class RankingExpressionWithTensorFlowTestCase { } private void assertSmallConstant(String name, TensorType type, RankProfileSearchFixture search) { - Value value = search.rankProfile("my_profile").getConstants().get(name); + Value value = search.compiledRankProfile("my_profile").getConstants().get(name); assertNotNull(value); assertEquals(type, value.type()); } @@ -410,7 +420,7 @@ public class RankingExpressionWithTensorFlowTestCase { String macroName, StoringApplicationPackage application) { try { - return new RankProfileSearchFixture( + RankProfileSearchFixture fixture = new RankProfileSearchFixture( application, application.getQueryProfiles(), " rank-profile my_profile {\n" + @@ -423,13 +433,15 @@ public class RankingExpressionWithTensorFlowTestCase { " }", constant, field); + fixture.compileRankProfile("my_profile"); + return fixture; } catch (ParseException e) { throw new IllegalArgumentException(e); } } - private RankProfileSearchFixture fixtureWith(String rankProfile, StoringApplicationPackage application) { + private RankProfileSearchFixture fixtureWithUncompiled(String rankProfile, StoringApplicationPackage application) { try { return new RankProfileSearchFixture(application, application.getQueryProfiles(), rankProfile, null, null); @@ -463,6 +475,21 @@ public class RankingExpressionWithTensorFlowTestCase { return new StoringApplicationPackageFile(file, Path.fromString(root.toString())); } + @Override + public List<NamedReader> getFiles(Path path, String suffix) { + List<NamedReader> readers = new ArrayList<>(); + for (File file : getFileReference(path).listFiles()) { + if ( ! file.getName().endsWith(suffix)) continue; + try { + readers.add(new NamedReader(file.getName(), new FileReader(file))); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + return readers; + } + } static class StoringApplicationPackageFile extends ApplicationFile { diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java index dba2bdbfbbf..0866d3192cf 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java @@ -25,6 +25,7 @@ public class RankingExpressionWithTensorTestCase { " }\n" + " }\n" + " }"); + f.compileRankProfile("my_profile"); f.assertFirstPhaseExpression("reduce(constant(my_tensor), sum)", "my_profile"); f.assertRankProperty("{{x:1,y:2}:1.0,{x:2,y:1}:2.0}", "constant(my_tensor).value", "my_profile"); f.assertRankProperty("tensor(x{},y{})", "constant(my_tensor).type", "my_profile"); @@ -47,6 +48,7 @@ public class RankingExpressionWithTensorTestCase { " }\n" + " }\n" + " }"); + f.compileRankProfile("my_profile"); f.assertFirstPhaseExpression("reduce(constant(my_tensor), sum)", "my_profile"); f.assertRankProperty("{{x:1,y:2}:1.0,{x:2,y:1}:2.0}", "constant(my_tensor).value", "my_profile"); f.assertRankProperty("tensor(x{},y{})", "constant(my_tensor).type", "my_profile"); @@ -65,6 +67,7 @@ public class RankingExpressionWithTensorTestCase { " }\n" + " }\n" + " }"); + f.compileRankProfile("my_profile"); f.assertSecondPhaseExpression("reduce(constant(my_tensor), sum)", "my_profile"); f.assertRankProperty("{{x:1}:1.0}", "constant(my_tensor).value", "my_profile"); f.assertRankProperty("tensor(x{})", "constant(my_tensor).type", "my_profile"); @@ -85,6 +88,7 @@ public class RankingExpressionWithTensorTestCase { " expression: sum(my_tensor)\n" + " }\n" + " }"); + f.compileRankProfile("my_profile"); f.assertFirstPhaseExpression("reduce(constant(my_tensor), sum)", "my_profile"); f.assertRankProperty("{{x:1}:1.0}", "constant(my_tensor).value", "my_profile"); f.assertRankProperty("tensor(x{})", "constant(my_tensor).type", "my_profile"); @@ -106,6 +110,7 @@ public class RankingExpressionWithTensorTestCase { " }\n" + " }\n" + " }"); + f.compileRankProfile("my_profile"); f.assertFirstPhaseExpression("5.0 + my_macro", "my_profile"); f.assertMacro("reduce(constant(my_tensor), sum)", "my_macro", "my_profile"); f.assertRankProperty("{{x:1}:1.0}", "constant(my_tensor).value", "my_profile"); @@ -127,6 +132,7 @@ public class RankingExpressionWithTensorTestCase { " my_number_2: 5.0\n" + " }\n" + " }"); + f.compileRankProfile("my_profile"); f.assertFirstPhaseExpression("3.0 + reduce(constant(my_tensor), sum) + 5.0", "my_profile"); f.assertRankProperty("{{x:1}:1.0}", "constant(my_tensor).value", "my_profile"); f.assertRankProperty("tensor(x{})", "constant(my_tensor).type", "my_profile"); @@ -139,7 +145,7 @@ public class RankingExpressionWithTensorTestCase { public void requireThatInvalidTensorTypeSpecThrowsException() throws ParseException { exception.expect(IllegalArgumentException.class); exception.expectMessage("For constant tensor 'my_tensor' in rank profile 'my_profile': Illegal tensor type spec: Failed parsing element 'x' in type spec 'tensor(x)'"); - new RankProfileSearchFixture( + RankProfileSearchFixture f = new RankProfileSearchFixture( " rank-profile my_profile {\n" + " constants {\n" + " my_tensor {\n" + @@ -148,6 +154,7 @@ public class RankingExpressionWithTensorTestCase { " }\n" + " }\n" + " }"); + f.compileRankProfile("my_profile"); } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithXgboostTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithXgboostTestCase.java index b65cb0b3d5f..f98783ad671 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithXgboostTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithXgboostTestCase.java @@ -36,7 +36,7 @@ public class RankingExpressionWithXgboostTestCase { String field, RankingExpressionWithTensorFlowTestCase.StoringApplicationPackage application) { try { - return new RankProfileSearchFixture( + RankProfileSearchFixture fixture = new RankProfileSearchFixture( application, application.getQueryProfiles(), " rank-profile my_profile {\n" + @@ -46,6 +46,8 @@ public class RankingExpressionWithXgboostTestCase { " }", constant, field); + fixture.compileRankProfile("my_profile"); + return fixture; } catch (ParseException e) { throw new IllegalArgumentException(e); } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/zookeeper/ZKApplicationPackage.java b/configserver/src/main/java/com/yahoo/vespa/config/server/zookeeper/ZKApplicationPackage.java index 2c46f591037..e9b4d6ac1aa 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/zookeeper/ZKApplicationPackage.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/zookeeper/ZKApplicationPackage.java @@ -24,12 +24,14 @@ import java.io.File; import java.io.Reader; import java.io.StringReader; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.stream.Collectors; /** * Represents an application residing in zookeeper. @@ -200,16 +202,17 @@ public class ZKApplicationPackage implements ApplicationPackage { return ret; } - //Returns readers for all the children of a node. - //The node is looked up relative to the location of the active application package - //in zookeeper. + /** + * Returns readers for all the children of a node. + * The node is looked up relative to the location of the active application package in zookeeper. + */ @Override - public List<NamedReader> getFiles(Path relativePath,String suffix,boolean recurse) { + public List<NamedReader> getFiles(Path relativePath, String suffix, boolean recurse) { return liveApp.getAllDataFromDirectory(ConfigCurator.USERAPP_ZK_SUBPATH + '/' + relativePath.getRelative(), suffix, recurse); } @Override - public ApplicationFile getFile(Path file) { // foo/bar/baz.json + public ApplicationFile getFile(Path file) { return new ZKApplicationFile(file, liveApp); } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/zookeeper/ZKLiveApp.java b/configserver/src/main/java/com/yahoo/vespa/config/server/zookeeper/ZKLiveApp.java index d7d43dea022..956af02e36f 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/zookeeper/ZKLiveApp.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/zookeeper/ZKLiveApp.java @@ -69,7 +69,8 @@ public class ZKLiveApp { log.finer("ZKApplicationPackage: Skipped '" + child + "' (did not match suffix " + fileNameSuffix + ")"); } if (recursive) - result.addAll(getAllDataFromDirectory(path + "/" + child, namePrefix + child + "/", fileNameSuffix, recursive)); + result.addAll(getAllDataFromDirectory(path + "/" + child, + namePrefix + child + "/", fileNameSuffix, recursive)); } if (log.isLoggable(Level.FINE)) log.fine("ZKApplicationPackage: Found '" + result.size() + "' files in " + fullPath); @@ -80,14 +81,15 @@ public class ZKLiveApp { } /** - * Retrieves a node relative to the node of the live application, e.g. /vespa/config/apps/$lt;app_id>/<path>/<node> + * Retrieves a node relative to the node of the live application, + * e.g. /vespa/config/apps/$lt;app_id>/<path>/<node> * * @param path a path relative to the currently active application * @param node a path relative to the path above * @return a Reader that can be used to get the data */ public Reader getDataReader(String path, String node) { - final String data = getData(path, node); + String data = getData(path, node); if (data == null) { throw new IllegalArgumentException("No node for " + getFullPath(path) + "/" + node + " exists"); } @@ -98,7 +100,8 @@ public class ZKLiveApp { try { return zk.getData(getFullPath(path), node); } catch (Exception e) { - throw new IllegalArgumentException("Could not retrieve node '" + getFullPath(path) + "/" + node + "' in zookeeper", e); + throw new IllegalArgumentException("Could not retrieve node '" + + getFullPath(path) + "/" + node + "' in zookeeper", e); } } @@ -205,5 +208,6 @@ public class ZKLiveApp { } return reader(data); } + } diff --git a/vespajlib/src/main/java/com/yahoo/path/Path.java b/vespajlib/src/main/java/com/yahoo/path/Path.java index c466fe50d6f..2806631be18 100644 --- a/vespajlib/src/main/java/com/yahoo/path/Path.java +++ b/vespajlib/src/main/java/com/yahoo/path/Path.java @@ -84,6 +84,7 @@ public final class Path { /** * Get the name of this path element, typically the last element in the path string. + * * @return the name */ public String getName() { |