summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-08-29 16:01:28 +0200
committerGitHub <noreply@github.com>2018-08-29 16:01:28 +0200
commit6958f2a641eaad0c61249e7bca887a1405e17d02 (patch)
tree49c8299f2fffbff2d814eea0cf86e8081d90cce8
parent06323aff51bf054d64ef2bea001917a22433717f (diff)
parentc27d2709eea8b697f8e099c1af12872bb7a75610 (diff)
Merge pull request #6722 from vespa-engine/revert-6713-bratseth/generate-rank-profiles-for-all-models-part-10
Revert "Read stored models from Zk package for global rank profiles"
-rw-r--r--config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java6
-rw-r--r--config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java126
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java245
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java6
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java4
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java4
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java35
-rw-r--r--config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java13
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java1
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java123
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java24
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java18
13 files changed, 274 insertions, 335 deletions
diff --git a/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java b/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java
index c7258d8aede..f926259f115 100644
--- a/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java
+++ b/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java
@@ -159,10 +159,12 @@ public interface ApplicationPackage {
}
/**
- * Returns inforamtion about a file
+ * Gets a file from the root of the application package
+ *
*
* @param relativePath the relative path of the file within this application package.
- * @return information abut the file, returned whether or not the file exists
+ * @return information abut the file
+ * @throws IllegalArgumentException if the given path does not exist
*/
ApplicationFile getFile(Path relativePath);
diff --git a/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java b/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java
index 757dab4cbf3..7404ae14a5d 100644
--- a/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java
+++ b/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java
@@ -16,17 +16,11 @@ import com.yahoo.searchdefinition.parser.ParseException;
import com.yahoo.vespa.config.ConfigDefinitionKey;
import com.yahoo.config.application.api.ApplicationPackage;
-import java.io.BufferedInputStream;
import java.io.File;
-import java.io.FileInputStream;
-import java.io.FileNotFoundException;
import java.io.IOException;
-import java.io.InputStream;
import java.io.Reader;
import java.io.StringReader;
-import java.io.UncheckedIOException;
import java.util.*;
-import java.util.stream.Collectors;
/**
* For testing purposes only
@@ -119,7 +113,7 @@ public class MockApplicationPackage implements ApplicationPackage {
@Override
public ApplicationFile getFile(Path file) {
- return new MockApplicationFile(file, Path.fromString(root.toString()));
+ throw new UnsupportedOperationException();
}
@Override
@@ -306,122 +300,4 @@ public class MockApplicationPackage implements ApplicationPackage {
return xmlStringWithIdAttribute.substring(idStart + 4, idEnd - 1);
}
- public static class MockApplicationFile extends ApplicationFile {
-
- /** The path to the application package root */
- private final Path root;
-
- /** The File pointing to the actual file represented by this */
- private final File file;
-
- public MockApplicationFile(Path filePath, Path applicationPackagePath) {
- super(filePath);
- this.root = applicationPackagePath;
- file = applicationPackagePath.append(filePath).toFile();
- }
-
- @Override
- public boolean isDirectory() {
- return file.isDirectory();
- }
-
- @Override
- public boolean exists() {
- return file.exists();
- }
-
- @Override
- public Reader createReader() throws FileNotFoundException {
- try {
- if ( ! exists()) throw new FileNotFoundException("File '" + file + "' does not exist");
- return IOUtils.createReader(file, "UTF-8");
- }
- catch (IOException e) {
- throw new UncheckedIOException(e);
- }
- }
-
- @Override
- public InputStream createInputStream() throws FileNotFoundException {
- try {
- if ( ! exists()) throw new FileNotFoundException("File '" + file + "' does not exist");
- return new BufferedInputStream(new FileInputStream(file));
- }
- catch (IOException e) {
- throw new UncheckedIOException(e);
- }
- }
-
- @Override
- public ApplicationFile createDirectory() {
- file.mkdirs();
- return this;
- }
-
- @Override
- public ApplicationFile writeFile(Reader input) {
- try {
- IOUtils.writeFile(file, IOUtils.readAll(input), false);
- return this;
- }
- catch (IOException e) {
- throw new UncheckedIOException(e);
- }
- }
-
- @Override
- public ApplicationFile appendFile(String value) {
- try {
- IOUtils.writeFile(file, value, true);
- return this;
- }
- catch (IOException e) {
- throw new UncheckedIOException(e);
- }
- }
-
- @Override
- public List<ApplicationFile> listFiles(PathFilter filter) {
- if ( ! isDirectory()) return Collections.emptyList();
- return Arrays.stream(file.listFiles()).filter(f -> filter.accept(Path.fromString(f.toString())))
- .map(f -> new MockApplicationFile(asApplicationRelativePath(f), root))
- .collect(Collectors.toList());
- }
-
- @Override
- public ApplicationFile delete() {
- file.delete();
- return this;
- }
-
- @Override
- public MetaData getMetaData() {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public int compareTo(ApplicationFile other) {
- return this.getPath().getName().compareTo((other).getPath().getName());
- }
-
- /** Strips the application package root path prefix from the path of the given file */
- private Path asApplicationRelativePath(File file) {
- Path path = Path.fromString(file.toString());
-
- Iterator<String> pathIterator = path.iterator();
- // Skip the path elements this shares with the root
- for (Iterator<String> rootIterator = root.iterator(); rootIterator.hasNext(); ) {
- String rootElement = rootIterator.next();
- String pathElement = pathIterator.next();
- if ( ! rootElement.equals(pathElement)) throw new RuntimeException("Assumption broken");
- }
- // Build a path from the remaining
- Path relative = Path.fromString("");
- while (pathIterator.hasNext())
- relative = relative.append(pathIterator.next());
- return relative;
- }
-
- }
-
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
index 935b9200868..0911f567fa1 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
@@ -1,6 +1,5 @@
package com.yahoo.searchdefinition.expressiontransforms;
-import com.google.common.collect.ImmutableMap;
import com.yahoo.collections.Pair;
import com.yahoo.config.application.api.ApplicationFile;
import com.yahoo.config.application.api.ApplicationPackage;
@@ -17,6 +16,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
+import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
@@ -38,6 +38,7 @@ import com.yahoo.tensor.serialization.TypedBinaryFormat;
import java.io.BufferedReader;
import java.io.File;
+import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.Reader;
import java.io.StringReader;
@@ -64,91 +65,49 @@ import java.util.stream.Collectors;
public class ConvertedModel {
private final String modelName;
- private final String modelDescription;
- private final ImmutableMap<String, RankingExpression> expressions;
-
- /** The source importedModel, or empty if this was created from a stored converted model */
- private final Optional<ImportedModel> sourceModel;
-
- private ConvertedModel(String modelName,
- String modelDescription,
- Map<String, RankingExpression> expressions,
- Optional<ImportedModel> sourceModel) {
- this.modelName = modelName;
- this.modelDescription = modelDescription;
- this.expressions = ImmutableMap.copyOf(expressions);
- this.sourceModel = sourceModel;
- }
+ private final Path modelPath;
/**
- * Create and store a converted model for a rank profile given from either an imported model,
- * or (if unavailable) from stored application package data.
+ * The ranking expressions of this, indexed by their name. which is a 1-3 part string separated by dots
+ * where the first part is always the model name, the second the signature or (if none)
+ * expression name (if more than one), and the third is the output name (if any).
*/
- public static ConvertedModel fromSourceOrStore(Path modelPath, RankProfileTransformContext context) {
- File sourceModel = sourceModelFile(context.rankProfile().applicationPackage(), modelPath);
- if (sourceModel.exists())
- return fromSource(toModelName(modelPath),
- modelPath.toString(),
- context.rankProfile(),
- context.queryProfiles(),
- context.importedModels().get(sourceModel)); // TODO: Convert to name here, make sure its done just one way
- else
- return fromStore(toModelName(modelPath),
- modelPath.toString(),
- context.rankProfile());
- }
-
- public static ConvertedModel fromSource(String modelName,
- String modelDescription,
- RankProfile rankProfile,
- QueryProfileRegistry queryProfileRegistry,
- ImportedModel importedModel) {
- ModelStore modelStore = new ModelStore(rankProfile.applicationPackage(), modelName);
- return new ConvertedModel(modelName,
- modelDescription,
- convertAndStore(importedModel, rankProfile, queryProfileRegistry, modelStore),
- Optional.of(importedModel));
- }
-
- public static ConvertedModel fromStore(String modelName,
- String modelDescription,
- RankProfile rankProfile) {
- ModelStore modelStore = new ModelStore(rankProfile.applicationPackage(), modelName);
- return new ConvertedModel(modelName,
- modelDescription,
- convertStored(modelStore, rankProfile),
- Optional.empty());
- }
+ private final Map<String, RankingExpression> expressions;
/**
- * Returns all the output expressions of this indexed by name. The names consist of one or two parts
- * separated by dot, where the first part is the signature name
- * if signatures are used, or the expression name if signatures are not used and there are multiple
- * expressions, and the second is the output name if signature names are used.
+ * Create a converted model for a rank profile given from either an imported model,
+ * or (if unavailable) from stored application package data.
*/
- public Map<String, RankingExpression> expressions() { return expressions; }
+ public ConvertedModel(Path modelPath, RankProfileTransformContext context) {
+ this.modelPath = modelPath;
+ this.modelName = toModelName(modelPath);
+ ModelStore store = new ModelStore(context.rankProfile().applicationPackage(), modelPath);
+ if ( store.hasSourceModel())
+ expressions = convertModel(store, context.rankProfile(), context.queryProfiles(), context.importedModels());
+ else
+ expressions = transformFromStoredModel(store, context.rankProfile());
+ }
- /**
- * Returns the expression matching the given arguments.
- */
- public ExpressionNode expression(FeatureArguments arguments, RankProfileTransformContext context) {
- RankingExpression expression = selectExpression(arguments);
- if (sourceModel.isPresent()) // we can verify
- verifyRequiredMacros(expression, sourceModel.get(), context.rankProfile(), context.queryProfiles());
- return expression.getRoot();
+ private Map<String, RankingExpression> convertModel(ModelStore store,
+ RankProfile profile,
+ QueryProfileRegistry queryProfiles,
+ ImportedModels importedModels) {
+ ImportedModel model = importedModels.get(store.sourceModelFile());
+ return transformFromImportedModel(model, store, profile, queryProfiles);
}
- private RankingExpression selectExpression(FeatureArguments arguments) {
+ /** Returns the expression matching the given arguments */
+ public ExpressionNode expression(FeatureArguments arguments) {
if (expressions.isEmpty())
throw new IllegalArgumentException("No expressions available in " + this);
RankingExpression expression = expressions.get(arguments.toName());
- if (expression != null) return expression;
+ if (expression != null) return expression.getRoot();
if ( ! arguments.signature().isPresent()) {
if (expressions.size() > 1)
throw new IllegalArgumentException("Multiple candidate expressions " + missingExpressionMessageSuffix());
- return expressions.values().iterator().next();
+ return expressions.values().iterator().next().getRoot();
}
if ( ! arguments.output().isPresent()) {
@@ -160,23 +119,21 @@ public class ConvertedModel {
if (entriesWithTheRightPrefix.size() > 1)
throw new IllegalArgumentException("Multiple candidate expression named '" + arguments.signature().get() +
missingExpressionMessageSuffix());
- return entriesWithTheRightPrefix.get(0).getValue();
+ return entriesWithTheRightPrefix.get(0).getValue().getRoot();
}
throw new IllegalArgumentException("No expression '" + arguments.toName() + missingExpressionMessageSuffix());
}
private String missingExpressionMessageSuffix() {
- return "' in model '" + modelDescription + "'. " +
+ return "' in model '" + this.modelPath + "'. " +
"Available expressions: " + expressions.keySet().stream().collect(Collectors.joining(", "));
}
- // ----------------------- Static model conversion/storage below here
-
- private static Map<String, RankingExpression> convertAndStore(ImportedModel model,
- RankProfile profile,
- QueryProfileRegistry queryProfiles,
- ModelStore store) {
+ private Map<String, RankingExpression> transformFromImportedModel(ImportedModel model,
+ ModelStore store,
+ RankProfile profile,
+ QueryProfileRegistry queryProfiles) {
// Add constants
Set<String> constantsReplacedByMacros = new HashSet<>();
model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v));
@@ -204,21 +161,22 @@ public class ConvertedModel {
return expressions;
}
- private static void addExpression(RankingExpression expression,
- String expressionName,
- Set<String> constantsReplacedByMacros,
- ImportedModel model,
- ModelStore store,
- RankProfile profile,
- QueryProfileRegistry queryProfiles,
- Map<String, RankingExpression> expressions) {
+ private void addExpression(RankingExpression expression,
+ String expressionName,
+ Set<String> constantsReplacedByMacros,
+ ImportedModel model,
+ ModelStore store,
+ RankProfile profile,
+ QueryProfileRegistry queryProfiles,
+ Map<String, RankingExpression> expressions) {
expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
+ verifyRequiredMacros(expression, model, profile, queryProfiles);
reduceBatchDimensions(expression, model, profile, queryProfiles);
store.writeExpression(expressionName, expression);
expressions.put(expressionName, expression);
}
- private static Map<String, RankingExpression> convertStored(ModelStore store, RankProfile profile) {
+ private Map<String, RankingExpression> transformFromStoredModel(ModelStore store, RankProfile profile) {
for (Pair<String, Tensor> constant : store.readSmallConstants())
profile.addConstant(constant.getFirst(), asValue(constant.getSecond()));
@@ -234,12 +192,12 @@ public class ConvertedModel {
return store.readExpressions();
}
- private static void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) {
+ private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) {
store.writeSmallConstant(constantName, constantValue);
profile.addConstant(constantName, asValue(constantValue));
}
- private static void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles,
+ private void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles,
Set<String> constantsReplacedByMacros,
String constantName, Tensor constantValue) {
RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName);
@@ -259,7 +217,7 @@ public class ConvertedModel {
}
}
- private static void transformGeneratedMacro(ModelStore store,
+ private void transformGeneratedMacro(ModelStore store,
Set<String> constantsReplacedByMacros,
String macroName,
RankingExpression expression) {
@@ -268,7 +226,7 @@ public class ConvertedModel {
store.writeMacro(macroName, expression);
}
- private static void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) {
+ private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) {
if (profile.getMacros().containsKey(macroName)) {
if ( ! profile.getMacros().get(macroName).getRankingExpression().equals(expression))
throw new IllegalArgumentException("Generated macro '" + macroName + "' already exists in " + profile +
@@ -285,8 +243,8 @@ public class ConvertedModel {
* Verify that the macros referred in the given expression exists in the given rank profile,
* and return tensors of the types specified in requiredMacros.
*/
- private static void verifyRequiredMacros(RankingExpression expression, ImportedModel model,
- RankProfile profile, QueryProfileRegistry queryProfiles) {
+ private void verifyRequiredMacros(RankingExpression expression, ImportedModel model,
+ RankProfile profile, QueryProfileRegistry queryProfiles) {
Set<String> macroNames = new HashSet<>();
addMacroNamesIn(expression.getRoot(), macroNames, model);
for (String macroName : macroNames) {
@@ -314,7 +272,7 @@ public class ConvertedModel {
}
}
- private static String typeMismatchExplanation(TensorType requiredType, TensorType actualType) {
+ private String typeMismatchExplanation(TensorType requiredType, TensorType actualType) {
return "The required type of this is " + requiredType + ", but this macro returns " + actualType +
(actualType.rank() == 0 ? ". This is often due to missing declaration of query tensor features " +
"in query profile types - see the documentation."
@@ -324,7 +282,7 @@ public class ConvertedModel {
/**
* Add the generated macros to the rank profile
*/
- private static void addGeneratedMacros(ImportedModel model, RankProfile profile) {
+ private void addGeneratedMacros(ImportedModel model, RankProfile profile) {
model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v.copy()));
}
@@ -333,8 +291,8 @@ public class ConvertedModel {
* macro specifies that a single exemplar should be evaluated, we can
* reduce the batch dimension out.
*/
- private static void reduceBatchDimensions(RankingExpression expression, ImportedModel model,
- RankProfile profile, QueryProfileRegistry queryProfiles) {
+ private void reduceBatchDimensions(RankingExpression expression, ImportedModel model,
+ RankProfile profile, QueryProfileRegistry queryProfiles) {
TypeContext<Reference> typeContext = profile.typeContext(queryProfiles);
TensorType typeBeforeReducing = expression.getRoot().type(typeContext);
@@ -361,8 +319,8 @@ public class ConvertedModel {
expression.setRoot(root);
}
- private static ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedModel model,
- TypeContext<Reference> typeContext) {
+ private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedModel model,
+ TypeContext<Reference> typeContext) {
if (node instanceof TensorFunctionNode) {
TensorFunction tensorFunction = ((TensorFunctionNode) node).function();
if (tensorFunction instanceof Rename) {
@@ -392,7 +350,7 @@ public class ConvertedModel {
return node;
}
- private static ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) {
+ private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) {
TensorFunction result = function;
TensorType type = function.type(context);
if (type.dimensions().size() > 1) {
@@ -414,7 +372,7 @@ public class ConvertedModel {
* for any following computation of the tensor.
*/
// TODO: determine when this is not necessary!
- private static ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
+ private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
if (after.equals(before)) {
return node;
}
@@ -441,14 +399,14 @@ public class ConvertedModel {
* If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions.
* This method does that for the given expression and returns the result.
*/
- private static RankingExpression replaceConstantsByMacros(RankingExpression expression,
+ private RankingExpression replaceConstantsByMacros(RankingExpression expression,
Set<String> constantsReplacedByMacros) {
if (constantsReplacedByMacros.isEmpty()) return expression;
return new RankingExpression(expression.getName(),
replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros));
}
- private static ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) {
+ private ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) {
if (node instanceof ReferenceNode) {
Reference reference = ((ReferenceNode)node).reference();
if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) {
@@ -466,7 +424,7 @@ public class ConvertedModel {
return node;
}
- private static void addMacroNamesIn(ExpressionNode node, Set<String> names, ImportedModel model) {
+ private void addMacroNamesIn(ExpressionNode node, Set<String> names, ImportedModel model) {
if (node instanceof ReferenceNode) {
ReferenceNode referenceNode = (ReferenceNode)node;
if (referenceNode.getOutput() == null) { // macro references cannot specify outputs
@@ -482,7 +440,7 @@ public class ConvertedModel {
}
}
- private static Value asValue(Tensor tensor) {
+ private Value asValue(Tensor tensor) {
if (tensor.type().rank() == 0)
return new DoubleValue(tensor.asDouble()); // the backend gets offended by dimensionless tensors
else
@@ -497,13 +455,6 @@ public class ConvertedModel {
public String toString() { return "model '" + modelName + "'"; }
/**
- * Returns the directory which contains the source model to use for these arguments
- */
- public static File sourceModelFile(ApplicationPackage application, Path sourceModelPath) {
- return application.getFileReference(ApplicationPackage.MODELS_DIR.append(sourceModelPath));
- }
-
- /**
* Provides read/write access to the correct directories of the application package given by the feature arguments
*/
static class ModelStore {
@@ -511,9 +462,20 @@ public class ConvertedModel {
private final ApplicationPackage application;
private final ModelFiles modelFiles;
- ModelStore(ApplicationPackage application, String modelName) {
+ ModelStore(ApplicationPackage application, Path modelPath) {
this.application = application;
- this.modelFiles = new ModelFiles(modelName);
+ this.modelFiles = new ModelFiles(modelPath);
+ }
+
+ public boolean hasSourceModel() {
+ return sourceModelFile().exists();
+ }
+
+ /**
+ * Returns the directory which contains the source model to use for these arguments
+ */
+ public File sourceModelFile() {
+ return application.getFileReference(ApplicationPackage.MODELS_DIR.append(modelFiles.modelPath()));
}
/**
@@ -586,7 +548,7 @@ public class ConvertedModel {
List<RankingConstant> readLargeConstants() {
try {
List<RankingConstant> constants = new ArrayList<>();
- for (ApplicationFile constantFile : application.getFile(modelFiles.largeConstantsInfoPath()).listFiles()) {
+ for (ApplicationFile constantFile : application.getFile(modelFiles.largeConstantsPath()).listFiles()) {
String[] parts = IOUtils.readAll(constantFile.createReader()).split(":");
constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2]));
}
@@ -604,13 +566,13 @@ public class ConvertedModel {
* @return the path to the stored constant, relative to the application package root
*/
Path writeLargeConstant(String name, Tensor constant) {
- Path constantsPath = modelFiles.largeConstantsContentPath();
+ Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(modelFiles.modelPath()).append("constants");
// "tbf" ending for "typed binary format" - recognized by the nodes receiving the file:
Path constantPath = constantsPath.append(name + ".tbf");
// Remember the constant in a file we replicate in ZooKeeper
- application.getFile(modelFiles.largeConstantsInfoPath().append(name + ".constant"))
+ application.getFile(modelFiles.largeConstantsPath().append(name + ".constant"))
.writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath)));
// Write content explicitly as a file on the file system as this is distributed using file distribution
@@ -647,8 +609,8 @@ public class ConvertedModel {
public void writeSmallConstant(String name, Tensor constant) {
// Secret file format for remembering constants:
application.getFile(modelFiles.smallConstantsPath()).appendFile(name + "\t" +
- constant.type().toString() + "\t" +
- constant.toString() + "\n");
+ constant.type().toString() + "\t" +
+ constant.toString() + "\n");
}
/** Workaround for being constructed with the .preprocessed dir as root while later being used outside it */
@@ -670,24 +632,40 @@ public class ConvertedModel {
}
}
+ private void close(Reader reader) {
+ try {
+ if (reader != null)
+ reader.close();
+ }
+ catch (IOException e) {
+ // ignore
+ }
+ }
+
}
static class ModelFiles {
- String modelName;
+ Path modelPath;
- public ModelFiles(String modelName) {
- this.modelName = modelName;
+ public ModelFiles(Path modelPath) {
+ this.modelPath = modelPath;
}
+ /** Returns modelPath with slashes replaced by underscores */
+ public String modelName() { return modelPath.toString().replace('/', '_').replace('.', '_'); }
+
+ /** Returns relative path to this model below the "models/" dir in the application package */
+ public Path modelPath() { return modelPath; }
+
/** Files stored below this path will be replicated in zookeeper */
public Path storedModelReplicatedPath() {
- return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelName);
+ return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath());
}
- /** Files stored below this path will not be replicated in zookeeper */
+ /** Files stored below this path will not be replicated */
public Path storedModelPath() {
- return ApplicationPackage.MODELS_GENERATED_DIR.append(modelName);
+ return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath());
}
public Path expressionPath(String name) {
@@ -703,12 +681,7 @@ public class ConvertedModel {
}
/** Path to the large (ranking) constants directory */
- public Path largeConstantsContentPath() {
- return storedModelPath().append("constants");
- }
-
- /** Path to the large (ranking) constants directory */
- public Path largeConstantsInfoPath() {
+ public Path largeConstantsPath() {
return storedModelReplicatedPath().append("constants");
}
@@ -722,19 +695,27 @@ public class ConvertedModel {
/** Encapsulates the arguments of a specific model output */
static class FeatureArguments {
+ private final String modelName;
+ private final Path modelPath;
+
/** Optional arguments */
private final Optional<String> signature, output;
public FeatureArguments(Arguments arguments) {
- this(optionalArgument(1, arguments),
+ this(Path.fromString(asString(arguments.expressions().get(0))),
+ optionalArgument(1, arguments),
optionalArgument(2, arguments));
}
- public FeatureArguments(Optional<String> signature, Optional<String> output) {
+ public FeatureArguments(Path modelPath, Optional<String> signature, Optional<String> output) {
+ this.modelPath = modelPath;
+ this.modelName = toModelName(modelPath);
this.signature = signature;
this.output = output;
}
+ public Path modelPath() { return modelPath; }
+
public Optional<String> signature() { return signature; }
public Optional<String> output() { return output; }
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
index 229ae0ebaaf..36dc200f3c9 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
@@ -13,7 +13,6 @@ import java.io.File;
import java.io.UncheckedIOException;
import java.util.HashMap;
import java.util.Map;
-import java.util.Optional;
/**
* Replaces instances of the onnx(model-path, output)
@@ -42,11 +41,10 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans
if ( ! feature.getName().equals("onnx")) return feature;
try {
- // TODO: Put modelPath in FeatureArguments instead
Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0)));
ConvertedModel convertedModel =
- convertedOnnxModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, context));
- return convertedModel.expression(asFeatureArguments(feature.getArguments()), context);
+ convertedOnnxModels.computeIfAbsent(modelPath, __ -> new ConvertedModel(modelPath, context));
+ return convertedModel.expression(asFeatureArguments(feature.getArguments()));
}
catch (IllegalArgumentException | UncheckedIOException e) {
throw new IllegalArgumentException("Could not use Onnx model from " + feature, e);
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index bcb8ef1521d..619c13da764 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -41,8 +41,8 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
try {
Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0)));
ConvertedModel convertedModel =
- convertedTensorFlowModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, context));
- return convertedModel.expression(asFeatureArguments(feature.getArguments()), context);
+ convertedTensorFlowModels.computeIfAbsent(modelPath, __ -> new ConvertedModel(modelPath, context));
+ return convertedModel.expression(asFeatureArguments(feature.getArguments()));
}
catch (IllegalArgumentException | UncheckedIOException e) {
throw new IllegalArgumentException("Could not use tensorflow model from " + feature, e);
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java
index b4a5069b9d6..e6b08ab0350 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java
@@ -43,8 +43,8 @@ public class XgboostFeatureConverter extends ExpressionTransformer<RankProfileTr
try {
Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0)));
ConvertedModel convertedModel =
- convertedXGBoostModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, context));
- return convertedModel.expression(asFeatureArguments(feature.getArguments()), context);
+ convertedXGBoostModels.computeIfAbsent(modelPath, __ -> new ConvertedModel(modelPath, context));
+ return convertedModel.expression(asFeatureArguments(feature.getArguments()));
} catch (IllegalArgumentException | UncheckedIOException e) {
throw new IllegalArgumentException("Could not use XGBoost model from " + feature, e);
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
index ec92905cd1f..73dd60f63eb 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
@@ -8,7 +8,6 @@ import com.yahoo.config.ConfigInstance;
import com.yahoo.config.ConfigInstance.Builder;
import com.yahoo.config.ConfigurationRuntimeException;
import com.yahoo.config.FileReference;
-import com.yahoo.config.application.api.ApplicationFile;
import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.config.application.api.DeployLogger;
import com.yahoo.config.application.api.ValidationId;
@@ -27,13 +26,11 @@ import com.yahoo.config.model.producer.AbstractConfigProducerRoot;
import com.yahoo.config.model.producer.UserConfigRepo;
import com.yahoo.config.provision.AllocatedHosts;
import com.yahoo.log.LogLevel;
-import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.RankProfileRegistry;
import com.yahoo.searchdefinition.RankingConstants;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RankProfileList;
-import com.yahoo.searchdefinition.expressiontransforms.ConvertedModel;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
@@ -165,9 +162,7 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
this.applicationPackage = deployState.getApplicationPackage();
root = builder.getRoot(VespaModel.ROOT_CONFIGID, deployState, this);
- createGlobalRankProfiles(deployState.getImportedModels(),
- deployState.rankProfileRegistry(),
- deployState.getQueryProfiles().getRegistry());
+ createGlobalRankProfiles(deployState.getImportedModels(), deployState.rankProfileRegistry());
this.rankProfileList = new RankProfileList(null, // null search -> global
AttributeFields.empty,
deployState.rankProfileRegistry(),
@@ -225,30 +220,14 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
* Creates a rank profile not attached to any search definition, for each imported model in the application package
*/
private ImmutableList<RankProfile> createGlobalRankProfiles(ImportedModels importedModels,
- RankProfileRegistry rankProfileRegistry,
- QueryProfileRegistry queryProfiles) {
+ RankProfileRegistry rankProfileRegistry) {
List<RankProfile> profiles = new ArrayList<>();
- if ( ! importedModels.all().isEmpty()) { // models/ directory is available
- for (ImportedModel model : importedModels.all()) {
- RankProfile profile = new RankProfile(model.name(), this, rankProfileRegistry);
- rankProfileRegistry.add(profile);
- ConvertedModel convertedModel = ConvertedModel.fromSource(model.name(), model.name(), profile, queryProfiles, model);
- for (Map.Entry<String, RankingExpression> entry : convertedModel.expressions().entrySet()) {
- profile.addMacro(entry.getKey(), false).setRankingExpression(entry.getValue());
- }
- }
- }
- else { // generated and saved model information may be available instead
- ApplicationFile generatedModelsDir = applicationPackage.getFile(ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR);
- for (ApplicationFile generatedModelDir : generatedModelsDir.listFiles()) {
- String modelName = generatedModelDir.getPath().last();
- RankProfile profile = new RankProfile(modelName, this, rankProfileRegistry);
- rankProfileRegistry.add(profile);
- ConvertedModel convertedModel = ConvertedModel.fromStore(modelName, modelName, profile);
- for (Map.Entry<String, RankingExpression> entry : convertedModel.expressions().entrySet()) {
- profile.addMacro(entry.getKey(), false).setRankingExpression(entry.getValue());
- }
+ for (ImportedModel model : importedModels.all()) {
+ RankProfile profile = new RankProfile(model.name(), this, rankProfileRegistry);
+ for (Pair<String, RankingExpression> entry : model.outputExpressions()) {
+ profile.addMacro(entry.getFirst(), false).setRankingExpression(entry.getSecond());
}
+ rankProfileRegistry.add(profile);
}
return ImmutableList.copyOf(profiles);
}
diff --git a/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java
index 8558ccc44bd..d06752c9b6d 100644
--- a/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java
+++ b/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java
@@ -2,14 +2,10 @@ package com.yahoo.config.model;
import ai.vespa.models.evaluation.Model;
import ai.vespa.models.evaluation.ModelsEvaluator;
-import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.config.model.application.provider.FilesApplicationPackage;
-import com.yahoo.io.IOUtils;
-import com.yahoo.path.Path;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.vespa.model.VespaModel;
import com.yahoo.vespa.model.container.ContainerCluster;
-import org.junit.After;
import org.junit.Test;
import org.xml.sax.SAXException;
@@ -26,16 +22,11 @@ import static org.junit.Assert.assertTrue;
*/
public class ModelEvaluationTest {
- private static final String appDir = "src/test/cfg/application/ml_serving";
-
- @After
- public void removeGeneratedModelFiles() {
- IOUtils.recursiveDeleteDir(Path.fromString(appDir).append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
- }
+ private static final String TESTDIR = "src/test/cfg/application/";
@Test
public void testMl_ServingApplication() throws SAXException, IOException {
- ApplicationPackageTester tester = ApplicationPackageTester.create(appDir);
+ ApplicationPackageTester tester = ApplicationPackageTester.create(TESTDIR + "ml_serving");
VespaModel model = new VespaModel(tester.app());
ContainerCluster cluster = model.getContainerClusters().get("container");
RankProfilesConfig.Builder b = new RankProfilesConfig.Builder();
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
index 04a6f953bb6..815a01cdb99 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
@@ -142,6 +142,7 @@ public class RankingExpressionWithOnnxTestCase {
}
}
+
@Test
public void testOnnxReferenceWithWrongMacroType() {
try {
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index 28fcf871cf3..c317f07b87a 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -403,7 +403,7 @@ public class RankingExpressionWithTensorFlowTestCase {
*/
private void assertLargeConstant(String name, RankProfileSearchFixture search, Optional<Long> expectedSize) {
try {
- Path constantApplicationPackagePath = Path.fromString("models.generated/mnist_softmax_saved/constants").append(name + ".tbf");
+ Path constantApplicationPackagePath = Path.fromString("models.generated/mnist_softmax/saved/constants").append(name + ".tbf");
RankingConstant rankingConstant = search.search().rankingConstants().get(name);
assertEquals(name, rankingConstant.getName());
assertTrue(rankingConstant.getFileName().endsWith(constantApplicationPackagePath.toString()));
@@ -485,7 +485,7 @@ public class RankingExpressionWithTensorFlowTestCase {
@Override
public ApplicationFile getFile(Path file) {
- return new MockApplicationFile(file, Path.fromString(root().toString()));
+ return new StoringApplicationPackageFile(file, Path.fromString(root().toString()));
}
@Override
@@ -505,4 +505,123 @@ public class RankingExpressionWithTensorFlowTestCase {
}
+ static class StoringApplicationPackageFile extends ApplicationFile {
+
+ /** The path to the application package root */
+ private final Path root;
+
+ /** The File pointing to the actual file represented by this */
+ private final File file;
+
+ StoringApplicationPackageFile(Path filePath, Path applicationPackagePath) {
+ super(filePath);
+ this.root = applicationPackagePath;
+ file = applicationPackagePath.append(filePath).toFile();
+ }
+
+ @Override
+ public boolean isDirectory() {
+ return file.isDirectory();
+ }
+
+ @Override
+ public boolean exists() {
+ return file.exists();
+ }
+
+ @Override
+ public Reader createReader() throws FileNotFoundException {
+ try {
+ if ( ! exists()) throw new FileNotFoundException("File '" + file + "' does not exist");
+ return IOUtils.createReader(file, "UTF-8");
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ @Override
+ public InputStream createInputStream() throws FileNotFoundException {
+ try {
+ if ( ! exists()) throw new FileNotFoundException("File '" + file + "' does not exist");
+ return new BufferedInputStream(new FileInputStream(file));
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ @Override
+ public ApplicationFile createDirectory() {
+ file.mkdirs();
+ return this;
+ }
+
+ @Override
+ public ApplicationFile writeFile(Reader input) {
+ try {
+ IOUtils.writeFile(file, IOUtils.readAll(input), false);
+ return this;
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ @Override
+ public ApplicationFile appendFile(String value) {
+ try {
+ IOUtils.writeFile(file, value, true);
+ return this;
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ @Override
+ public List<ApplicationFile> listFiles(PathFilter filter) {
+ if ( ! isDirectory()) return Collections.emptyList();
+ return Arrays.stream(file.listFiles()).filter(f -> filter.accept(Path.fromString(f.toString())))
+ .map(f -> new StoringApplicationPackageFile(asApplicationRelativePath(f),
+ root))
+ .collect(Collectors.toList());
+ }
+
+ @Override
+ public ApplicationFile delete() {
+ file.delete();
+ return this;
+ }
+
+ @Override
+ public MetaData getMetaData() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int compareTo(ApplicationFile other) {
+ return this.getPath().getName().compareTo((other).getPath().getName());
+ }
+
+ /** Strips the application package root path prefix from the path of the given file */
+ private Path asApplicationRelativePath(File file) {
+ Path path = Path.fromString(file.toString());
+
+ Iterator<String> pathIterator = path.iterator();
+ // Skip the path elements this shares with the root
+ for (Iterator<String> rootIterator = root.iterator(); rootIterator.hasNext(); ) {
+ String rootElement = rootIterator.next();
+ String pathElement = pathIterator.next();
+ if ( ! rootElement.equals(pathElement)) throw new RuntimeException("Assumption broken");
+ }
+ // Build a path from the remaining
+ Path relative = Path.fromString("");
+ while (pathIterator.hasNext())
+ relative = relative.append(pathIterator.next());
+ return relative;
+ }
+
+ }
+
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
index ca5cdc3cc56..50e83fd1a1e 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
@@ -41,8 +41,8 @@ public class ModelsEvaluator extends AbstractComponent {
Model requireModel(String name) {
Model model = models.get(name);
if (model == null)
- throw new IllegalArgumentException("No model named '" + name + "'. Available models: " +
- String.join(", ", models.keySet()));
+ throw new IllegalArgumentException("No model named '" + name + ". Available models: " +
+ models.keySet().stream().collect(Collectors.joining(", ")));
return model;
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
index f7fe91cb56f..6716993e1dd 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
@@ -98,33 +98,33 @@ public class ImportedModel {
void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); }
/**
- * Returns all the output expressions of this indexed by name. The names consist of one or two parts
- * separated by dot, where the first part is the signature name
+ * Returns all the outputs of this by name. The names consist of one to three parts
+ * separated by dot, where the first part is the model name, the second is the signature name
* if signatures are used, or the expression name if signatures are not used and there are multiple
- * expressions, and the second is the output name if signature names are used.
+ * expressions, and the third is the output name if signature names are used.
*/
public List<Pair<String, RankingExpression>> outputExpressions() {
- List<Pair<String, RankingExpression>> expressions = new ArrayList<>();
+ List<Pair<String, RankingExpression>> names = new ArrayList<>();
for (Map.Entry<String, Signature> signatureEntry : signatures().entrySet()) {
for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet())
- expressions.add(new Pair<>(signatureEntry.getKey() + "." + outputEntry.getKey(),
- expressions().get(outputEntry.getValue())));
+ names.add(new Pair<>(signatureEntry.getKey() + "." + outputEntry.getKey(),
+ expressions().get(outputEntry.getValue())));
if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs
- expressions.add(new Pair<>(signatureEntry.getKey(),
- expressions().get(signatureEntry.getKey())));
+ names.add(new Pair<>(signatureEntry.getKey(),
+ expressions().get(signatureEntry.getKey())));
}
if (signatures().isEmpty()) { // fallback for models without signatures
if (expressions().size() == 1) {
- Map.Entry<String, RankingExpression> singleEntry = this.expressions.entrySet().iterator().next();
- expressions.add(new Pair<>(singleEntry.getKey(), singleEntry.getValue()));
+ Map.Entry<String, RankingExpression> singleEntry = expressions.entrySet().iterator().next();
+ names.add(new Pair<>(singleEntry.getKey(), singleEntry.getValue()));
}
else {
for (Map.Entry<String, RankingExpression> expressionEntry : expressions().entrySet()) {
- expressions.add(new Pair<>(expressionEntry.getKey(), expressionEntry.getValue()));
+ names.add(new Pair<>(expressionEntry.getKey(), expressionEntry.getValue()));
}
}
}
- return expressions;
+ return names;
}
/**
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java
index 32d33622a33..b1714b49256 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java
@@ -10,9 +10,7 @@ import java.util.Collection;
import java.util.Optional;
/**
- * All models imported from the models/ directory in the application package.
- * If this is empty it may be due to either not having any models in the application package,
- * or this being created for a ZooKeeper application package, which does not have imported models.
+ * All models imported from the models/ directory in the application package
*
* @author bratseth
*/
@@ -56,22 +54,16 @@ public class ImportedModels {
}
/**
- * Returns the model at the given location in the application package.
+ * Returns the model at the given location in the application package (lazily loaded),
*
- * @param modelPath the path to this model (file or directory, depending on model type)
- * under the application package, both from the root or relative to the
- * models directory works
- * @return the model at this path or null if none
+ * @param modelPath the full path to this model (file or directory, depending on model type)
+ * under the application package
+ * @throws IllegalArgumentException if the model cannot be loaded
*/
public ImportedModel get(File modelPath) {
- System.out.println("Name from " + modelPath + ": " + toName(modelPath));
return importedModels.get(toName(modelPath));
}
- public ImportedModel get(String modelName) {
- return importedModels.get(modelName);
- }
-
/** Returns an immutable collection of all the imported models */
public Collection<ImportedModel> all() {
return importedModels.values();