diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-08-15 20:36:17 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-08-15 20:36:17 +0200 |
commit | 8f3ebc6bcecb686b5ffcc6393e1f612e37b2ad7d (patch) | |
tree | b34977667c2b29d4f0f4dc2397f7d68fdeafe4a2 /config-model/src/main/java/com/yahoo/searchdefinition | |
parent | 80d703603a4e3c5282af45a3c36dfa46888622a7 (diff) |
Revert "Convert all outputs"
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition')
5 files changed, 172 insertions, 229 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfileRegistry.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfileRegistry.java index 6311751bb88..7b4d70d85b1 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfileRegistry.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfileRegistry.java @@ -52,11 +52,11 @@ public class RankProfileRegistry { } private void checkForDuplicateRankProfile(RankProfile rankProfile) { - String rankProfileName = rankProfile.getName(); + final String rankProfileName = rankProfile.getName(); RankProfile existingRangProfileWithSameName = rankProfiles.get(rankProfile.getSearch()).get(rankProfileName); if (existingRangProfileWithSameName == null) return; - if ( ! overridableRankProfileNames.contains(rankProfileName)) { + if (!overridableRankProfileNames.contains(rankProfileName)) { throw new IllegalArgumentException("Cannot add rank profile '" + rankProfileName + "' in search definition '" + rankProfile.getSearch().getName() + "', since it already exists"); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java index e073be71a0c..867740c7912 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java @@ -1,11 +1,11 @@ package com.yahoo.searchdefinition.expressiontransforms; +import com.google.common.base.Joiner; import com.yahoo.collections.Pair; import com.yahoo.config.application.api.ApplicationFile; import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.config.model.application.provider.FilesApplicationPackage; import com.yahoo.io.IOUtils; -import com.yahoo.io.reader.NamedReader; import com.yahoo.path.Path; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.FeatureNames; @@ -40,12 +40,10 @@ import com.yahoo.tensor.serialization.TypedBinaryFormat; import java.io.BufferedReader; import java.io.File; import java.io.IOException; -import java.io.Reader; import java.io.StringReader; import java.io.UncheckedIOException; import java.util.ArrayList; import java.util.Collections; -import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -64,142 +62,69 @@ import java.util.stream.Collectors; */ public class ConvertedModel { - private final String modelName; - private final Path modelPath; + private final ExpressionNode convertedExpression; - /** - * The ranking expressions of this, indexed by their name. which is a 1-3 part string separated by dots - * where the first part is always the model name, the second the signature or (if none) - * expression name (if more than one), and the third is the output name (if any). - */ - private final Map<String, RankingExpression> expressions; - - public ConvertedModel(Path modelPath, + public ConvertedModel(FeatureArguments arguments, RankProfileTransformContext context, - ModelImporter modelImporter) { - this.modelPath = modelPath; - this.modelName = toModelName(modelPath); - ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), modelPath); + ModelImporter modelImporter, + Map<Path, ImportedModel> importedModels) { + ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), arguments); if ( ! store.hasStoredModel()) // not converted yet - access from models/ directory - expressions = importModel(store, context.rankProfile(), context.queryProfiles(), modelImporter); + convertedExpression = importModel(store, context.rankProfile(), context.queryProfiles(), modelImporter, importedModels); else - expressions = transformFromStoredModel(store, context.rankProfile()); + convertedExpression = transformFromStoredModel(store, context.rankProfile()); } - private Map<String, RankingExpression> importModel(ModelStore store, - RankProfile profile, - QueryProfileRegistry queryProfiles, - ModelImporter modelImporter) { - ImportedModel model = modelImporter.importModel(store.modelFiles.modelName(), store.modelDir()); + private ExpressionNode importModel(ModelStore store, + RankProfile profile, + QueryProfileRegistry queryProfiles, + ModelImporter modelImporter, + Map<Path, ImportedModel> importedModels) { + ImportedModel model = importedModels.computeIfAbsent(store.arguments().modelPath(), + k -> modelImporter.importModel(store.arguments().modelName(), + store.modelDir())); return transformFromImportedModel(model, store, profile, queryProfiles); } - /** Returns the expression matching the given arguments */ - public ExpressionNode expression(FeatureArguments arguments) { - if (expressions.isEmpty()) - throw new IllegalArgumentException("No expressions available in " + this); - - RankingExpression expression = expressions.get(arguments.toName()); - if (expression != null) return expression.getRoot(); + public ExpressionNode expression() { return convertedExpression; } - if ( ! arguments.signature().isPresent()) { - if (expressions.size() > 1) - throw new IllegalArgumentException("Multiple candidate expressions " + missingExpressionMessageSuffix()); - return expressions.values().iterator().next().getRoot(); - } - - if ( ! arguments.output().isPresent()) { - List<Map.Entry<String, RankingExpression>> entriesWithTheRightPrefix = - expressions.entrySet().stream().filter(entry -> entry.getKey().startsWith(modelName + "." + arguments.signature().get() + ".")).collect(Collectors.toList()); - if (entriesWithTheRightPrefix.size() < 1) - throw new IllegalArgumentException("No expressions named '" + arguments.signature().get() + - missingExpressionMessageSuffix()); - if (entriesWithTheRightPrefix.size() > 1) - throw new IllegalArgumentException("Multiple candidate expression named '" + arguments.signature().get() + - missingExpressionMessageSuffix()); - return entriesWithTheRightPrefix.get(0).getValue().getRoot(); - } - - throw new IllegalArgumentException("No expression '" + arguments.toName() + missingExpressionMessageSuffix()); - } - - private String missingExpressionMessageSuffix() { - return "' in model '" + this.modelPath + "'. " + - "Available expressions: " + expressions.keySet().stream().collect(Collectors.joining(", ")); - } - - private Map<String, RankingExpression> transformFromImportedModel(ImportedModel model, - ModelStore store, - RankProfile profile, - QueryProfileRegistry queryProfiles) { + private ExpressionNode transformFromImportedModel(ImportedModel model, + ModelStore store, + RankProfile profile, + QueryProfileRegistry queryProfiles) { // Add constants Set<String> constantsReplacedByMacros = new HashSet<>(); model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v)); model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles, constantsReplacedByMacros, k, v)); - // Add macros - addGeneratedMacros(model, profile); - - // Add expressions - Map<String, RankingExpression> expressions = new HashMap<>(); - for (Map.Entry<String, ImportedModel.Signature> signatureEntry : model.signatures().entrySet()) { - for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet()) { - addExpression(model.expressions().get(outputEntry.getValue()), - modelName + "." + signatureEntry.getKey() + "." + outputEntry.getKey(), - constantsReplacedByMacros, - model, store, profile, queryProfiles, - expressions); + // Find the specified expression + ImportedModel.Signature signature = chooseSignature(model, store.arguments().signature()); + String output = chooseOutput(signature, store.arguments().output()); + if (signature.skippedOutputs().containsKey(output)) { + String message = "Could not import model output '" + output + "'"; + if (!signature.skippedOutputs().get(output).isEmpty()) { + message += ": " + signature.skippedOutputs().get(output); } - if (signatureEntry.getValue().outputs().isEmpty()) { // fallback: Signature without outputs - addExpression(model.expressions().get(signatureEntry.getKey()), - modelName + "." + signatureEntry.getKey(), - constantsReplacedByMacros, - model, store, profile, queryProfiles, - expressions); - } - } - if (model.signatures().isEmpty()) { // fallback: Model without signatures - if (model.expressions().size() == 1) { // Use just model name - addExpression(model.expressions().values().iterator().next(), - modelName, - constantsReplacedByMacros, - model, store, profile, queryProfiles, - expressions); - } - else { - for (Map.Entry<String, RankingExpression> expressionEntry : model.expressions().entrySet()) { - addExpression(expressionEntry.getValue(), - modelName + "." + expressionEntry.getKey(), - constantsReplacedByMacros, - model, store, profile, queryProfiles, - expressions); - } + if (!signature.importWarnings().isEmpty()) { + message += ": " + String.join(", ", signature.importWarnings()); } + throw new IllegalArgumentException(message); } - // Transform and save macro - must come after reading expressions due to optimization transforms - model.macros().forEach((k, v) -> transformGeneratedMacro(store, constantsReplacedByMacros, k, v)); - - return expressions; - } - - private void addExpression(RankingExpression expression, - String expressionName, - Set<String> constantsReplacedByMacros, - ImportedModel model, - ModelStore store, - RankProfile profile, - QueryProfileRegistry queryProfiles, - Map<String, RankingExpression> expressions) { + RankingExpression expression = model.expressions().get(output); expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); verifyRequiredMacros(expression, model, profile, queryProfiles); + addGeneratedMacros(model, profile); reduceBatchDimensions(expression, model, profile, queryProfiles); - store.writeExpression(expressionName, expression); - expressions.put(expressionName, expression); + + model.macros().forEach((k, v) -> transformGeneratedMacro(store, constantsReplacedByMacros, k, v)); + + store.writeConverted(expression); + return expression.getRoot(); } - private Map<String, RankingExpression> transformFromStoredModel(ModelStore store, RankProfile profile) { + ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) { for (Pair<String, Tensor> constant : store.readSmallConstants()) profile.addConstant(constant.getFirst(), asValue(constant.getSecond())); @@ -212,7 +137,60 @@ public class ConvertedModel { addGeneratedMacroToProfile(profile, macro.getFirst(), macro.getSecond()); } - return store.readExpressions(); + return store.readConverted().getRoot(); + } + + /** + * Returns the specified, existing signature, or the only signature if none is specified. + * Throws IllegalArgumentException in all other cases. + */ + private ImportedModel.Signature chooseSignature(ImportedModel importResult, Optional<String> signatureName) { + if ( ! signatureName.isPresent()) { + if (importResult.signatures().size() == 0) + throw new IllegalArgumentException("No signatures are available"); + if (importResult.signatures().size() > 1) + throw new IllegalArgumentException("Model has multiple signatures (" + + Joiner.on(", ").join(importResult.signatures().keySet()) + + "), one must be specified " + + "as a second argument to tensorflow()"); + return importResult.signatures().values().stream().findFirst().get(); + } + else { + ImportedModel.Signature signature = importResult.signatures().get(signatureName.get()); + if (signature == null) + throw new IllegalArgumentException("Model does not have the specified signature '" + + signatureName.get() + "'"); + return signature; + } + } + + /** + * Returns the specified, existing output expression, or the only output expression if no output name is specified. + * Throws IllegalArgumentException in all other cases. + */ + private String chooseOutput(ImportedModel.Signature signature, Optional<String> outputName) { + if ( ! outputName.isPresent()) { + if (signature.outputs().size() == 0) + throw new IllegalArgumentException("No outputs are available" + skippedOutputsDescription(signature)); + if (signature.outputs().size() > 1) + throw new IllegalArgumentException(signature + " has multiple outputs (" + + Joiner.on(", ").join(signature.outputs().keySet()) + + "), one must be specified " + + "as a third argument to tensorflow()"); + return signature.outputs().get(signature.outputs().keySet().stream().findFirst().get()); + } + else { + String output = signature.outputs().get(outputName.get()); + if (output == null) { + if (signature.skippedOutputs().containsKey(outputName.get())) + throw new IllegalArgumentException("Could not use output '" + outputName.get() + "': " + + signature.skippedOutputs().get(outputName.get())); + else + throw new IllegalArgumentException("Model does not have the specified output '" + + outputName.get() + "'"); + } + return output; + } } private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) { @@ -249,15 +227,22 @@ public class ConvertedModel { } private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) { - if (profile.getMacros().containsKey(macroName)) + if (profile.getMacros().containsKey(macroName)) { throw new IllegalArgumentException("Generated macro '" + macroName + "' already exists."); - + } profile.addMacro(macroName, false); // todo: inline if only used once RankProfile.Macro macro = profile.getMacros().get(macroName); macro.setRankingExpression(expression); macro.setTextualExpression(expression.getRoot().toString()); } + private String skippedOutputsDescription(ImportedModel.Signature signature) { + if (signature.skippedOutputs().isEmpty()) return ""; + StringBuilder b = new StringBuilder(": "); + signature.skippedOutputs().forEach((k, v) -> b.append("Skipping output '").append(k).append("': ").append(v)); + return b.toString(); + } + /** * Verify that the macros referred in the given expression exists in the given rank profile, * and return tensors of the types specified in requiredMacros. @@ -390,8 +375,8 @@ public class ConvertedModel { /** * If batch dimensions have been reduced away above, bring them back here * for any following computation of the tensor. + * Todo: determine when this is not necessary! */ - // TODO: determine when this is not necessary! private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) { if (after.equals(before)) { return node; @@ -467,29 +452,24 @@ public class ConvertedModel { return new TensorValue(tensor); } - private static String toModelName(Path modelPath) { - return modelPath.toString().replace("/", "_"); - } - - @Override - public String toString() { return "model '" + modelName + "'"; } - /** * Provides read/write access to the correct directories of the application package given by the feature arguments */ static class ModelStore { private final ApplicationPackage application; - private final ModelFiles modelFiles; + private final FeatureArguments arguments; - ModelStore(ApplicationPackage application, Path modelPath) { + ModelStore(ApplicationPackage application, FeatureArguments arguments) { this.application = application; - this.modelFiles = new ModelFiles(modelPath); + this.arguments = arguments; } + public FeatureArguments arguments() { return arguments; } + public boolean hasStoredModel() { try { - return application.getFileReference(modelFiles.storedModelPath()).exists(); + return application.getFile(arguments.expressionPath()).exists(); } catch (UnsupportedOperationException e) { return false; @@ -500,49 +480,40 @@ public class ConvertedModel { * Returns the directory which contains the source model to use for these arguments */ public File modelDir() { - return application.getFileReference(ApplicationPackage.MODELS_DIR.append(modelFiles.modelPath())); + return application.getFileReference(ApplicationPackage.MODELS_DIR.append(arguments.modelPath())); } /** * Adds this expression to the application package, such that it can be read later. - * - * @param name the name of this ranking expression - may have 1-3 parts separated by dot where the first part - * is always the model name */ - void writeExpression(String name, RankingExpression expression) { - application.getFile(modelFiles.expressionPath(name)) + void writeConverted(RankingExpression expression) { + application.getFile(arguments.expressionPath()) .writeFile(new StringReader(expression.getRoot().toString())); } - Map<String, RankingExpression> readExpressions() { - Map<String, RankingExpression> expressions = new HashMap<>(); - List<NamedReader> expressionReaders = null; + /** Reads the previously stored ranking expression for these arguments */ + RankingExpression readConverted() { try { - expressionReaders = application.getFiles(modelFiles.expressionsPath(), "expression"); - for (NamedReader expressionReader : expressionReaders) { - try { - expressions.put(expressionReader.getName(), new RankingExpression(expressionReader.getReader())); - } catch (ParseException e) { - throw new IllegalStateException("Could not parse " + expressionReader.getName(), e); - } - } + return new RankingExpression(application.getFile(arguments.expressionPath()).createReader()); + } + catch (IOException e) { + throw new UncheckedIOException("Could not read " + arguments.expressionPath(), e); } - finally { - expressionReaders.forEach(r -> close(r)); + catch (ParseException e) { + throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e); } - return expressions; } /** Adds this macro expression to the application package to it can be read later. */ void writeMacro(String name, RankingExpression expression) { - application.getFile(modelFiles.macrosPath()).appendFile(name + "\t" + + application.getFile(arguments.macrosPath()).appendFile(name + "\t" + expression.getRoot().toString() + "\n"); } /** Reads the previously stored macro expressions for these arguments */ List<Pair<String, RankingExpression>> readMacros() { try { - ApplicationFile file = application.getFile(modelFiles.macrosPath()); + ApplicationFile file = application.getFile(arguments.macrosPath()); if (!file.exists()) return Collections.emptyList(); List<Pair<String, RankingExpression>> macros = new ArrayList<>(); @@ -556,7 +527,7 @@ public class ConvertedModel { macros.add(new Pair<>(name, expression)); } catch (ParseException e) { - throw new IllegalStateException("Could not parse " + name, e); + throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e); } } return macros; @@ -573,7 +544,7 @@ public class ConvertedModel { List<RankingConstant> readLargeConstants() { try { List<RankingConstant> constants = new ArrayList<>(); - for (ApplicationFile constantFile : application.getFile(modelFiles.largeConstantsPath()).listFiles()) { + for (ApplicationFile constantFile : application.getFile(arguments.largeConstantsPath()).listFiles()) { String[] parts = IOUtils.readAll(constantFile.createReader()).split(":"); constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2])); } @@ -591,13 +562,13 @@ public class ConvertedModel { * @return the path to the stored constant, relative to the application package root */ Path writeLargeConstant(String name, Tensor constant) { - Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(modelFiles.modelPath()).append("constants"); + Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants"); // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file: Path constantPath = constantsPath.append(name + ".tbf"); // Remember the constant in a file we replicate in ZooKeeper - application.getFile(modelFiles.largeConstantsPath().append(name + ".constant")) + application.getFile(arguments.largeConstantsPath().append(name + ".constant")) .writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath))); // Write content explicitly as a file on the file system as this is distributed using file distribution @@ -608,7 +579,7 @@ public class ConvertedModel { private List<Pair<String, Tensor>> readSmallConstants() { try { - ApplicationFile file = application.getFile(modelFiles.smallConstantsPath()); + ApplicationFile file = application.getFile(arguments.smallConstantsPath()); if (!file.exists()) return Collections.emptyList(); List<Pair<String, Tensor>> constants = new ArrayList<>(); @@ -633,7 +604,7 @@ public class ConvertedModel { */ public void writeSmallConstant(String name, Tensor constant) { // Secret file format for remembering constants: - application.getFile(modelFiles.smallConstantsPath()).appendFile(name + "\t" + + application.getFile(arguments.smallConstantsPath()).appendFile(name + "\t" + constant.type().toString() + "\t" + constant.toString() + "\n"); } @@ -657,24 +628,26 @@ public class ConvertedModel { } } - private void close(Reader reader) { - try { - if (reader != null) - reader.close(); - } - catch (IOException e) { - // ignore - } - } - } - static class ModelFiles { + /** Encapsulates the arguments to the import feature */ + static class FeatureArguments { Path modelPath; - public ModelFiles(Path modelPath) { + /** Optional arguments */ + Optional<String> signature, output; + + public FeatureArguments(Arguments arguments) { + this(Path.fromString(asString(arguments.expressions().get(0))), + optionalArgument(1, arguments), + optionalArgument(2, arguments)); + } + + public FeatureArguments(Path modelPath, Optional<String> signature, Optional<String> output) { this.modelPath = modelPath; + this.signature = signature; + this.output = output; } /** Returns modelPath with slashes replaced by underscores */ @@ -682,66 +655,37 @@ public class ConvertedModel { /** Returns relative path to this model below the "models/" dir in the application package */ public Path modelPath() { return modelPath; } + public Optional<String> signature() { return signature; } + public Optional<String> output() { return output; } - public Path storedModelPath() { - return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath()); - } - - public Path expressionPath(String name) { - return storedModelPath().append("expressions").append(name + ".expression"); - } - - public Path expressionsPath() { - return storedModelPath().append("expressions"); - } - + /** Path to the small constants file */ public Path smallConstantsPath() { - return storedModelPath().append("constants.txt"); + return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath).append("constants.txt"); } /** Path to the large (ranking) constants directory */ public Path largeConstantsPath() { - return storedModelPath().append("constants"); + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants"); } /** Path to the macros file */ public Path macrosPath() { - return storedModelPath().append("macros.txt"); - } - - } - - /** Encapsulates the arguments of a specific model output */ - static class FeatureArguments { - - private final String modelName; - private final Path modelPath; - - /** Optional arguments */ - private final Optional<String> signature, output; - - public FeatureArguments(Arguments arguments) { - this(Path.fromString(asString(arguments.expressions().get(0))), - optionalArgument(1, arguments), - optionalArgument(2, arguments)); + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("macros.txt"); } - public FeatureArguments(Path modelPath, Optional<String> signature, Optional<String> output) { - this.modelPath = modelPath; - this.modelName = toModelName(modelPath); - this.signature = signature; - this.output = output; + public Path expressionPath() { + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR + .append(modelPath).append("expressions").append(expressionFileName()); } - public Path modelPath() { return modelPath; } - - public Optional<String> signature() { return signature; } - public Optional<String> output() { return output; } - - public String toName() { - return modelName + - (signature.isPresent() ? "." + signature.get() : "") + - (output.isPresent() ? "." + output.get() : ""); + private String expressionFileName() { + StringBuilder fileName = new StringBuilder(); + signature.ifPresent(s -> fileName.append(s).append(".")); + output.ifPresent(s -> fileName.append(s).append(".")); + if (fileName.length() == 0) // single signature and output + fileName.append("single."); + fileName.append("expression"); + return fileName.toString(); } private static Optional<String> optionalArgument(int argumentIndex, Arguments arguments) { @@ -750,7 +694,7 @@ public class ConvertedModel { return Optional.of(asString(arguments.expressions().get(argumentIndex))); } - public static String asString(ExpressionNode node) { + private static String asString(ExpressionNode node) { if ( ! (node instanceof ConstantNode)) throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node); return stripQuotes(((ConstantNode)node).sourceString()); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java index 0dec12c4749..d31ffefde65 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java @@ -31,7 +31,7 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans private final OnnxImporter onnxImporter = new OnnxImporter(); /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */ - private final Map<Path, ConvertedModel> convertedModels = new HashMap<>(); + private final Map<Path, ImportedModel> importedModels = new HashMap<>(); @Override public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) { @@ -47,9 +47,9 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans if ( ! feature.getName().equals("onnx")) return feature; try { - Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0))); - ConvertedModel convertedModel = convertedModels.computeIfAbsent(modelPath, path -> new ConvertedModel(path, context, onnxImporter)); - return convertedModel.expression(asFeatureArguments(feature.getArguments())); + ConvertedModel convertedModel = new ConvertedModel(asFeatureArguments(feature.getArguments()), + context, onnxImporter, importedModels); + return convertedModel.expression(); } catch (IllegalArgumentException | UncheckedIOException e) { throw new IllegalArgumentException("Could not use Onnx model from " + feature, e); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index 585adc0c0d4..d28299b1d30 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -28,7 +28,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil private final TensorFlowImporter tensorFlowImporter = new TensorFlowImporter(); /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */ - private final Map<Path, ConvertedModel> convertedModels = new HashMap<>(); + private final Map<Path, ImportedModel> importedModels = new HashMap<>(); @Override public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) { @@ -44,9 +44,9 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil if ( ! feature.getName().equals("tensorflow")) return feature; try { - Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0))); - ConvertedModel convertedModel = convertedModels.computeIfAbsent(modelPath, path -> new ConvertedModel(path, context, tensorFlowImporter)); - return convertedModel.expression(asFeatureArguments(feature.getArguments())); + ConvertedModel convertedModel = new ConvertedModel(asFeatureArguments(feature.getArguments()), + context, tensorFlowImporter, importedModels); + return convertedModel.expression(); } catch (IllegalArgumentException | UncheckedIOException e) { throw new IllegalArgumentException("Could not use tensorflow model from " + feature, e); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java index 62f43e15849..4ae223ec3a5 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java @@ -37,8 +37,7 @@ public class XgboostFeatureConverter extends ExpressionTransformer<RankProfileTr try { ConvertedModel.FeatureArguments arguments = asFeatureArguments(feature.getArguments()); - ConvertedModel.ModelStore store = new ConvertedModel.ModelStore(context.rankProfile().getSearch().sourceApplication(), - arguments.modelPath()); + ConvertedModel.ModelStore store = new ConvertedModel.ModelStore(context.rankProfile().getSearch().sourceApplication(), arguments); RankingExpression expression = xgboostImporter.parseModel(store.modelDir().toString()); return expression.getRoot(); } catch (IllegalArgumentException | UncheckedIOException e) { @@ -49,7 +48,7 @@ public class XgboostFeatureConverter extends ExpressionTransformer<RankProfileTr private ConvertedModel.FeatureArguments asFeatureArguments(Arguments arguments) { if (arguments.isEmpty()) throw new IllegalArgumentException("An xgboost node must take an argument pointing to " + - "the xgboost model directory under [application]/models"); + "the xgboost model directory under [application]/models"); if (arguments.expressions().size() > 1) throw new IllegalArgumentException("An xgboost feature can have at most 1 argument"); |