aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/searchdefinition
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-08-20 08:56:52 +0200
committerJon Bratseth <bratseth@oath.com>2018-08-20 08:56:52 +0200
commit03e88f27bc2278fb711a8c4a8a85763b23067348 (patch)
tree3a325a5dc2f44ce7186ec202ea6b0fdc03b306ee /config-model/src/main/java/com/yahoo/searchdefinition
parente11152b73bb9f4a8020034a6bca63e82124f7b26 (diff)
Revert "Merge pull request #6619 from vespa-engine/revert-6611-revert-6596-revert-6584-bratseth/generate-rank-profiles-for-all-models-part-2"
This reverts commit 0437e8cc1d550fb8c6d24ffe4da813067c542f62, reversing changes made to 1715b8393827c159f8709033075066b29932f852.
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/RankProfileRegistry.java4
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java381
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java8
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java8
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java5
5 files changed, 232 insertions, 174 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfileRegistry.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfileRegistry.java
index 7b4d70d85b1..6311751bb88 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfileRegistry.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfileRegistry.java
@@ -52,11 +52,11 @@ public class RankProfileRegistry {
}
private void checkForDuplicateRankProfile(RankProfile rankProfile) {
- final String rankProfileName = rankProfile.getName();
+ String rankProfileName = rankProfile.getName();
RankProfile existingRangProfileWithSameName = rankProfiles.get(rankProfile.getSearch()).get(rankProfileName);
if (existingRangProfileWithSameName == null) return;
- if (!overridableRankProfileNames.contains(rankProfileName)) {
+ if ( ! overridableRankProfileNames.contains(rankProfileName)) {
throw new IllegalArgumentException("Cannot add rank profile '" + rankProfileName + "' in search definition '"
+ rankProfile.getSearch().getName() + "', since it already exists");
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
index 867740c7912..3bd96c9db26 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
@@ -1,11 +1,11 @@
package com.yahoo.searchdefinition.expressiontransforms;
-import com.google.common.base.Joiner;
import com.yahoo.collections.Pair;
import com.yahoo.config.application.api.ApplicationFile;
import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.config.model.application.provider.FilesApplicationPackage;
import com.yahoo.io.IOUtils;
+import com.yahoo.io.reader.NamedReader;
import com.yahoo.path.Path;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.FeatureNames;
@@ -39,11 +39,14 @@ import com.yahoo.tensor.serialization.TypedBinaryFormat;
import java.io.BufferedReader;
import java.io.File;
+import java.io.FileNotFoundException;
import java.io.IOException;
+import java.io.Reader;
import java.io.StringReader;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Collections;
+import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
@@ -62,69 +65,142 @@ import java.util.stream.Collectors;
*/
public class ConvertedModel {
- private final ExpressionNode convertedExpression;
+ private final String modelName;
+ private final Path modelPath;
- public ConvertedModel(FeatureArguments arguments,
+ /**
+ * The ranking expressions of this, indexed by their name. which is a 1-3 part string separated by dots
+ * where the first part is always the model name, the second the signature or (if none)
+ * expression name (if more than one), and the third is the output name (if any).
+ */
+ private final Map<String, RankingExpression> expressions;
+
+ public ConvertedModel(Path modelPath,
RankProfileTransformContext context,
- ModelImporter modelImporter,
- Map<Path, ImportedModel> importedModels) {
- ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), arguments);
+ ModelImporter modelImporter) {
+ this.modelPath = modelPath;
+ this.modelName = toModelName(modelPath);
+ ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), modelPath);
if ( ! store.hasStoredModel()) // not converted yet - access from models/ directory
- convertedExpression = importModel(store, context.rankProfile(), context.queryProfiles(), modelImporter, importedModels);
+ expressions = importModel(store, context.rankProfile(), context.queryProfiles(), modelImporter);
else
- convertedExpression = transformFromStoredModel(store, context.rankProfile());
+ expressions = transformFromStoredModel(store, context.rankProfile());
}
- private ExpressionNode importModel(ModelStore store,
- RankProfile profile,
- QueryProfileRegistry queryProfiles,
- ModelImporter modelImporter,
- Map<Path, ImportedModel> importedModels) {
- ImportedModel model = importedModels.computeIfAbsent(store.arguments().modelPath(),
- k -> modelImporter.importModel(store.arguments().modelName(),
- store.modelDir()));
+ private Map<String, RankingExpression> importModel(ModelStore store,
+ RankProfile profile,
+ QueryProfileRegistry queryProfiles,
+ ModelImporter modelImporter) {
+ ImportedModel model = modelImporter.importModel(store.modelFiles.modelName(), store.modelDir());
return transformFromImportedModel(model, store, profile, queryProfiles);
}
- public ExpressionNode expression() { return convertedExpression; }
+ /** Returns the expression matching the given arguments */
+ public ExpressionNode expression(FeatureArguments arguments) {
+ if (expressions.isEmpty())
+ throw new IllegalArgumentException("No expressions available in " + this);
+
+ RankingExpression expression = expressions.get(arguments.toName());
+ if (expression != null) return expression.getRoot();
+
+ if ( ! arguments.signature().isPresent()) {
+ if (expressions.size() > 1)
+ throw new IllegalArgumentException("Multiple candidate expressions " + missingExpressionMessageSuffix());
+ return expressions.values().iterator().next().getRoot();
+ }
+
+ if ( ! arguments.output().isPresent()) {
+ List<Map.Entry<String, RankingExpression>> entriesWithTheRightPrefix =
+ expressions.entrySet().stream().filter(entry -> entry.getKey().startsWith(modelName + "." + arguments.signature().get() + ".")).collect(Collectors.toList());
+ if (entriesWithTheRightPrefix.size() < 1)
+ throw new IllegalArgumentException("No expressions named '" + arguments.signature().get() +
+ missingExpressionMessageSuffix());
+ if (entriesWithTheRightPrefix.size() > 1)
+ throw new IllegalArgumentException("Multiple candidate expression named '" + arguments.signature().get() +
+ missingExpressionMessageSuffix());
+ return entriesWithTheRightPrefix.get(0).getValue().getRoot();
+ }
+
+ throw new IllegalArgumentException("No expression '" + arguments.toName() + missingExpressionMessageSuffix());
+ }
+
+ private String missingExpressionMessageSuffix() {
+ return "' in model '" + this.modelPath + "'. " +
+ "Available expressions: " + expressions.keySet().stream().collect(Collectors.joining(", "));
+ }
- private ExpressionNode transformFromImportedModel(ImportedModel model,
- ModelStore store,
- RankProfile profile,
- QueryProfileRegistry queryProfiles) {
+ private Map<String, RankingExpression> transformFromImportedModel(ImportedModel model,
+ ModelStore store,
+ RankProfile profile,
+ QueryProfileRegistry queryProfiles) {
// Add constants
Set<String> constantsReplacedByMacros = new HashSet<>();
model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v));
model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles,
constantsReplacedByMacros, k, v));
- // Find the specified expression
- ImportedModel.Signature signature = chooseSignature(model, store.arguments().signature());
- String output = chooseOutput(signature, store.arguments().output());
- if (signature.skippedOutputs().containsKey(output)) {
- String message = "Could not import model output '" + output + "'";
- if (!signature.skippedOutputs().get(output).isEmpty()) {
- message += ": " + signature.skippedOutputs().get(output);
+ // Add macros
+ addGeneratedMacros(model, profile);
+
+ // Add expressions
+ Map<String, RankingExpression> expressions = new HashMap<>();
+ for (Map.Entry<String, ImportedModel.Signature> signatureEntry : model.signatures().entrySet()) {
+ for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet()) {
+ addExpression(model.expressions().get(outputEntry.getValue()),
+ modelName + "." + signatureEntry.getKey() + "." + outputEntry.getKey(),
+ constantsReplacedByMacros,
+ model, store, profile, queryProfiles,
+ expressions);
}
- if (!signature.importWarnings().isEmpty()) {
- message += ": " + String.join(", ", signature.importWarnings());
+ if (signatureEntry.getValue().outputs().isEmpty()) { // fallback: Signature without outputs
+ addExpression(model.expressions().get(signatureEntry.getKey()),
+ modelName + "." + signatureEntry.getKey(),
+ constantsReplacedByMacros,
+ model, store, profile, queryProfiles,
+ expressions);
}
- throw new IllegalArgumentException(message);
}
+ if (model.signatures().isEmpty()) { // fallback: Model without signatures
+ if (model.expressions().size() == 1) { // Use just model name
+ addExpression(model.expressions().values().iterator().next(),
+ modelName,
+ constantsReplacedByMacros,
+ model, store, profile, queryProfiles,
+ expressions);
+ }
+ else {
+ for (Map.Entry<String, RankingExpression> expressionEntry : model.expressions().entrySet()) {
+ addExpression(expressionEntry.getValue(),
+ modelName + "." + expressionEntry.getKey(),
+ constantsReplacedByMacros,
+ model, store, profile, queryProfiles,
+ expressions);
+ }
+ }
+ }
+
+ // Transform and save macro - must come after reading expressions due to optimization transforms
+ model.macros().forEach((k, v) -> transformGeneratedMacro(store, constantsReplacedByMacros, k, v));
- RankingExpression expression = model.expressions().get(output);
+ return expressions;
+ }
+
+ private void addExpression(RankingExpression expression,
+ String expressionName,
+ Set<String> constantsReplacedByMacros,
+ ImportedModel model,
+ ModelStore store,
+ RankProfile profile,
+ QueryProfileRegistry queryProfiles,
+ Map<String, RankingExpression> expressions) {
expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
verifyRequiredMacros(expression, model, profile, queryProfiles);
- addGeneratedMacros(model, profile);
reduceBatchDimensions(expression, model, profile, queryProfiles);
-
- model.macros().forEach((k, v) -> transformGeneratedMacro(store, constantsReplacedByMacros, k, v));
-
- store.writeConverted(expression);
- return expression.getRoot();
+ store.writeExpression(expressionName, expression);
+ expressions.put(expressionName, expression);
}
- ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) {
+ private Map<String, RankingExpression> transformFromStoredModel(ModelStore store, RankProfile profile) {
for (Pair<String, Tensor> constant : store.readSmallConstants())
profile.addConstant(constant.getFirst(), asValue(constant.getSecond()));
@@ -137,60 +213,7 @@ public class ConvertedModel {
addGeneratedMacroToProfile(profile, macro.getFirst(), macro.getSecond());
}
- return store.readConverted().getRoot();
- }
-
- /**
- * Returns the specified, existing signature, or the only signature if none is specified.
- * Throws IllegalArgumentException in all other cases.
- */
- private ImportedModel.Signature chooseSignature(ImportedModel importResult, Optional<String> signatureName) {
- if ( ! signatureName.isPresent()) {
- if (importResult.signatures().size() == 0)
- throw new IllegalArgumentException("No signatures are available");
- if (importResult.signatures().size() > 1)
- throw new IllegalArgumentException("Model has multiple signatures (" +
- Joiner.on(", ").join(importResult.signatures().keySet()) +
- "), one must be specified " +
- "as a second argument to tensorflow()");
- return importResult.signatures().values().stream().findFirst().get();
- }
- else {
- ImportedModel.Signature signature = importResult.signatures().get(signatureName.get());
- if (signature == null)
- throw new IllegalArgumentException("Model does not have the specified signature '" +
- signatureName.get() + "'");
- return signature;
- }
- }
-
- /**
- * Returns the specified, existing output expression, or the only output expression if no output name is specified.
- * Throws IllegalArgumentException in all other cases.
- */
- private String chooseOutput(ImportedModel.Signature signature, Optional<String> outputName) {
- if ( ! outputName.isPresent()) {
- if (signature.outputs().size() == 0)
- throw new IllegalArgumentException("No outputs are available" + skippedOutputsDescription(signature));
- if (signature.outputs().size() > 1)
- throw new IllegalArgumentException(signature + " has multiple outputs (" +
- Joiner.on(", ").join(signature.outputs().keySet()) +
- "), one must be specified " +
- "as a third argument to tensorflow()");
- return signature.outputs().get(signature.outputs().keySet().stream().findFirst().get());
- }
- else {
- String output = signature.outputs().get(outputName.get());
- if (output == null) {
- if (signature.skippedOutputs().containsKey(outputName.get()))
- throw new IllegalArgumentException("Could not use output '" + outputName.get() + "': " +
- signature.skippedOutputs().get(outputName.get()));
- else
- throw new IllegalArgumentException("Model does not have the specified output '" +
- outputName.get() + "'");
- }
- return output;
- }
+ return store.readExpressions();
}
private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) {
@@ -227,22 +250,15 @@ public class ConvertedModel {
}
private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) {
- if (profile.getMacros().containsKey(macroName)) {
+ if (profile.getMacros().containsKey(macroName))
throw new IllegalArgumentException("Generated macro '" + macroName + "' already exists.");
- }
+
profile.addMacro(macroName, false); // todo: inline if only used once
RankProfile.Macro macro = profile.getMacros().get(macroName);
macro.setRankingExpression(expression);
macro.setTextualExpression(expression.getRoot().toString());
}
- private String skippedOutputsDescription(ImportedModel.Signature signature) {
- if (signature.skippedOutputs().isEmpty()) return "";
- StringBuilder b = new StringBuilder(": ");
- signature.skippedOutputs().forEach((k, v) -> b.append("Skipping output '").append(k).append("': ").append(v));
- return b.toString();
- }
-
/**
* Verify that the macros referred in the given expression exists in the given rank profile,
* and return tensors of the types specified in requiredMacros.
@@ -375,8 +391,8 @@ public class ConvertedModel {
/**
* If batch dimensions have been reduced away above, bring them back here
* for any following computation of the tensor.
- * Todo: determine when this is not necessary!
*/
+ // TODO: determine when this is not necessary!
private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
if (after.equals(before)) {
return node;
@@ -452,24 +468,29 @@ public class ConvertedModel {
return new TensorValue(tensor);
}
+ private static String toModelName(Path modelPath) {
+ return modelPath.toString().replace("/", "_");
+ }
+
+ @Override
+ public String toString() { return "model '" + modelName + "'"; }
+
/**
* Provides read/write access to the correct directories of the application package given by the feature arguments
*/
static class ModelStore {
private final ApplicationPackage application;
- private final FeatureArguments arguments;
+ private final ModelFiles modelFiles;
- ModelStore(ApplicationPackage application, FeatureArguments arguments) {
+ ModelStore(ApplicationPackage application, Path modelPath) {
this.application = application;
- this.arguments = arguments;
+ this.modelFiles = new ModelFiles(modelPath);
}
- public FeatureArguments arguments() { return arguments; }
-
public boolean hasStoredModel() {
try {
- return application.getFile(arguments.expressionPath()).exists();
+ return application.getFileReference(modelFiles.storedModelPath()).exists();
}
catch (UnsupportedOperationException e) {
return false;
@@ -480,40 +501,49 @@ public class ConvertedModel {
* Returns the directory which contains the source model to use for these arguments
*/
public File modelDir() {
- return application.getFileReference(ApplicationPackage.MODELS_DIR.append(arguments.modelPath()));
+ return application.getFileReference(ApplicationPackage.MODELS_DIR.append(modelFiles.modelPath()));
}
/**
* Adds this expression to the application package, such that it can be read later.
+ *
+ * @param name the name of this ranking expression - may have 1-3 parts separated by dot where the first part
+ * is always the model name
*/
- void writeConverted(RankingExpression expression) {
- application.getFile(arguments.expressionPath())
+ void writeExpression(String name, RankingExpression expression) {
+ application.getFile(modelFiles.expressionPath(name))
.writeFile(new StringReader(expression.getRoot().toString()));
}
- /** Reads the previously stored ranking expression for these arguments */
- RankingExpression readConverted() {
- try {
- return new RankingExpression(application.getFile(arguments.expressionPath()).createReader());
- }
- catch (IOException e) {
- throw new UncheckedIOException("Could not read " + arguments.expressionPath(), e);
- }
- catch (ParseException e) {
- throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e);
+ Map<String, RankingExpression> readExpressions() {
+ Map<String, RankingExpression> expressions = new HashMap<>();
+ ApplicationFile expressionPath = application.getFile(modelFiles.expressionsPath());
+ if ( ! expressionPath.exists() || ! expressionPath.isDirectory()) return Collections.emptyMap();
+ for (ApplicationFile expressionFile : expressionPath.listFiles()) {
+ try {
+ String name = expressionFile.getPath().getName();
+ expressions.put(name, new RankingExpression(name, expressionFile.createReader()));
+ }
+ catch (FileNotFoundException e) {
+ throw new IllegalStateException("Expression file removed while reading: " + expressionFile, e);
+ }
+ catch (ParseException e) {
+ throw new IllegalStateException("Invalid stored expression in " + expressionFile, e);
+ }
}
+ return expressions;
}
/** Adds this macro expression to the application package to it can be read later. */
void writeMacro(String name, RankingExpression expression) {
- application.getFile(arguments.macrosPath()).appendFile(name + "\t" +
+ application.getFile(modelFiles.macrosPath()).appendFile(name + "\t" +
expression.getRoot().toString() + "\n");
}
/** Reads the previously stored macro expressions for these arguments */
List<Pair<String, RankingExpression>> readMacros() {
try {
- ApplicationFile file = application.getFile(arguments.macrosPath());
+ ApplicationFile file = application.getFile(modelFiles.macrosPath());
if (!file.exists()) return Collections.emptyList();
List<Pair<String, RankingExpression>> macros = new ArrayList<>();
@@ -527,7 +557,7 @@ public class ConvertedModel {
macros.add(new Pair<>(name, expression));
}
catch (ParseException e) {
- throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e);
+ throw new IllegalStateException("Could not parse " + name, e);
}
}
return macros;
@@ -544,7 +574,7 @@ public class ConvertedModel {
List<RankingConstant> readLargeConstants() {
try {
List<RankingConstant> constants = new ArrayList<>();
- for (ApplicationFile constantFile : application.getFile(arguments.largeConstantsPath()).listFiles()) {
+ for (ApplicationFile constantFile : application.getFile(modelFiles.largeConstantsPath()).listFiles()) {
String[] parts = IOUtils.readAll(constantFile.createReader()).split(":");
constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2]));
}
@@ -562,13 +592,13 @@ public class ConvertedModel {
* @return the path to the stored constant, relative to the application package root
*/
Path writeLargeConstant(String name, Tensor constant) {
- Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants");
+ Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(modelFiles.modelPath()).append("constants");
// "tbf" ending for "typed binary format" - recognized by the nodes receiving the file:
Path constantPath = constantsPath.append(name + ".tbf");
// Remember the constant in a file we replicate in ZooKeeper
- application.getFile(arguments.largeConstantsPath().append(name + ".constant"))
+ application.getFile(modelFiles.largeConstantsPath().append(name + ".constant"))
.writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath)));
// Write content explicitly as a file on the file system as this is distributed using file distribution
@@ -579,7 +609,7 @@ public class ConvertedModel {
private List<Pair<String, Tensor>> readSmallConstants() {
try {
- ApplicationFile file = application.getFile(arguments.smallConstantsPath());
+ ApplicationFile file = application.getFile(modelFiles.smallConstantsPath());
if (!file.exists()) return Collections.emptyList();
List<Pair<String, Tensor>> constants = new ArrayList<>();
@@ -604,7 +634,7 @@ public class ConvertedModel {
*/
public void writeSmallConstant(String name, Tensor constant) {
// Secret file format for remembering constants:
- application.getFile(arguments.smallConstantsPath()).appendFile(name + "\t" +
+ application.getFile(modelFiles.smallConstantsPath()).appendFile(name + "\t" +
constant.type().toString() + "\t" +
constant.toString() + "\n");
}
@@ -628,26 +658,24 @@ public class ConvertedModel {
}
}
+ private void close(Reader reader) {
+ try {
+ if (reader != null)
+ reader.close();
+ }
+ catch (IOException e) {
+ // ignore
+ }
+ }
+
}
- /** Encapsulates the arguments to the import feature */
- static class FeatureArguments {
+ static class ModelFiles {
Path modelPath;
- /** Optional arguments */
- Optional<String> signature, output;
-
- public FeatureArguments(Arguments arguments) {
- this(Path.fromString(asString(arguments.expressions().get(0))),
- optionalArgument(1, arguments),
- optionalArgument(2, arguments));
- }
-
- public FeatureArguments(Path modelPath, Optional<String> signature, Optional<String> output) {
+ public ModelFiles(Path modelPath) {
this.modelPath = modelPath;
- this.signature = signature;
- this.output = output;
}
/** Returns modelPath with slashes replaced by underscores */
@@ -655,37 +683,66 @@ public class ConvertedModel {
/** Returns relative path to this model below the "models/" dir in the application package */
public Path modelPath() { return modelPath; }
- public Optional<String> signature() { return signature; }
- public Optional<String> output() { return output; }
- /** Path to the small constants file */
+ public Path storedModelPath() {
+ return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath());
+ }
+
+ public Path expressionPath(String name) {
+ return storedModelPath().append("expressions").append(name);
+ }
+
+ public Path expressionsPath() {
+ return storedModelPath().append("expressions");
+ }
+
public Path smallConstantsPath() {
- return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath).append("constants.txt");
+ return storedModelPath().append("constants.txt");
}
/** Path to the large (ranking) constants directory */
public Path largeConstantsPath() {
- return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants");
+ return storedModelPath().append("constants");
}
/** Path to the macros file */
public Path macrosPath() {
- return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("macros.txt");
+ return storedModelPath().append("macros.txt");
+ }
+
+ }
+
+ /** Encapsulates the arguments of a specific model output */
+ static class FeatureArguments {
+
+ private final String modelName;
+ private final Path modelPath;
+
+ /** Optional arguments */
+ private final Optional<String> signature, output;
+
+ public FeatureArguments(Arguments arguments) {
+ this(Path.fromString(asString(arguments.expressions().get(0))),
+ optionalArgument(1, arguments),
+ optionalArgument(2, arguments));
}
- public Path expressionPath() {
- return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR
- .append(modelPath).append("expressions").append(expressionFileName());
+ public FeatureArguments(Path modelPath, Optional<String> signature, Optional<String> output) {
+ this.modelPath = modelPath;
+ this.modelName = toModelName(modelPath);
+ this.signature = signature;
+ this.output = output;
}
- private String expressionFileName() {
- StringBuilder fileName = new StringBuilder();
- signature.ifPresent(s -> fileName.append(s).append("."));
- output.ifPresent(s -> fileName.append(s).append("."));
- if (fileName.length() == 0) // single signature and output
- fileName.append("single.");
- fileName.append("expression");
- return fileName.toString();
+ public Path modelPath() { return modelPath; }
+
+ public Optional<String> signature() { return signature; }
+ public Optional<String> output() { return output; }
+
+ public String toName() {
+ return modelName +
+ (signature.isPresent() ? "." + signature.get() : "") +
+ (output.isPresent() ? "." + output.get() : "");
}
private static Optional<String> optionalArgument(int argumentIndex, Arguments arguments) {
@@ -694,7 +751,7 @@ public class ConvertedModel {
return Optional.of(asString(arguments.expressions().get(argumentIndex)));
}
- private static String asString(ExpressionNode node) {
+ public static String asString(ExpressionNode node) {
if ( ! (node instanceof ConstantNode))
throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node);
return stripQuotes(((ConstantNode)node).sourceString());
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
index d31ffefde65..0dec12c4749 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
@@ -31,7 +31,7 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans
private final OnnxImporter onnxImporter = new OnnxImporter();
/** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */
- private final Map<Path, ImportedModel> importedModels = new HashMap<>();
+ private final Map<Path, ConvertedModel> convertedModels = new HashMap<>();
@Override
public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
@@ -47,9 +47,9 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans
if ( ! feature.getName().equals("onnx")) return feature;
try {
- ConvertedModel convertedModel = new ConvertedModel(asFeatureArguments(feature.getArguments()),
- context, onnxImporter, importedModels);
- return convertedModel.expression();
+ Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0)));
+ ConvertedModel convertedModel = convertedModels.computeIfAbsent(modelPath, path -> new ConvertedModel(path, context, onnxImporter));
+ return convertedModel.expression(asFeatureArguments(feature.getArguments()));
}
catch (IllegalArgumentException | UncheckedIOException e) {
throw new IllegalArgumentException("Could not use Onnx model from " + feature, e);
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index d28299b1d30..585adc0c0d4 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -28,7 +28,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
private final TensorFlowImporter tensorFlowImporter = new TensorFlowImporter();
/** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */
- private final Map<Path, ImportedModel> importedModels = new HashMap<>();
+ private final Map<Path, ConvertedModel> convertedModels = new HashMap<>();
@Override
public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
@@ -44,9 +44,9 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
if ( ! feature.getName().equals("tensorflow")) return feature;
try {
- ConvertedModel convertedModel = new ConvertedModel(asFeatureArguments(feature.getArguments()),
- context, tensorFlowImporter, importedModels);
- return convertedModel.expression();
+ Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0)));
+ ConvertedModel convertedModel = convertedModels.computeIfAbsent(modelPath, path -> new ConvertedModel(path, context, tensorFlowImporter));
+ return convertedModel.expression(asFeatureArguments(feature.getArguments()));
}
catch (IllegalArgumentException | UncheckedIOException e) {
throw new IllegalArgumentException("Could not use tensorflow model from " + feature, e);
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java
index 4ae223ec3a5..62f43e15849 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java
@@ -37,7 +37,8 @@ public class XgboostFeatureConverter extends ExpressionTransformer<RankProfileTr
try {
ConvertedModel.FeatureArguments arguments = asFeatureArguments(feature.getArguments());
- ConvertedModel.ModelStore store = new ConvertedModel.ModelStore(context.rankProfile().getSearch().sourceApplication(), arguments);
+ ConvertedModel.ModelStore store = new ConvertedModel.ModelStore(context.rankProfile().getSearch().sourceApplication(),
+ arguments.modelPath());
RankingExpression expression = xgboostImporter.parseModel(store.modelDir().toString());
return expression.getRoot();
} catch (IllegalArgumentException | UncheckedIOException e) {
@@ -48,7 +49,7 @@ public class XgboostFeatureConverter extends ExpressionTransformer<RankProfileTr
private ConvertedModel.FeatureArguments asFeatureArguments(Arguments arguments) {
if (arguments.isEmpty())
throw new IllegalArgumentException("An xgboost node must take an argument pointing to " +
- "the xgboost model directory under [application]/models");
+ "the xgboost model directory under [application]/models");
if (arguments.expressions().size() > 1)
throw new IllegalArgumentException("An xgboost feature can have at most 1 argument");