summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-08-22 14:36:03 +0200
committerJon Bratseth <bratseth@oath.com>2018-08-22 14:36:03 +0200
commit4eb133b40206e20e3a70dae7aacec0f6b117e15d (patch)
tree0bb6fb1c306da10125efc5a5c0d6eec6deae0c22
parent7392f9fdbee5f0a52ac9c056376b659b32500c60 (diff)
Scope imported models to an entire application build
-rw-r--r--config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java5
-rw-r--r--config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java8
-rw-r--r--config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java21
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java20
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java16
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java10
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java9
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java7
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ImportedModels.java49
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java6
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java5
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java6
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java7
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/search/StreamingSearchCluster.java3
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java3
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java7
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java4
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java16
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java13
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java13
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java3
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java7
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java3
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java7
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java4
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java3
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java3
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java8
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java8
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java33
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithXgboostTestCase.java2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java4
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java70
34 files changed, 234 insertions, 153 deletions
diff --git a/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java b/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java
index a71a0878d3d..f926259f115 100644
--- a/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java
+++ b/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java
@@ -231,9 +231,8 @@ public interface ApplicationPackage {
*/
ApplicationMetaData getMetaData();
- default File getFileReference(Path pathRelativeToAppDir) {
- throw new UnsupportedOperationException("This application package cannot return file references");
- }
+ File getFileReference(Path pathRelativeToAppDir);
+
default void validateXML() throws IOException {
throw new UnsupportedOperationException("This application package cannot validate XML");
}
diff --git a/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java b/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java
index ff6370f1738..574c25a2f84 100644
--- a/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java
+++ b/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java
@@ -1,7 +1,6 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.config.model.deploy;
-import com.google.common.collect.ImmutableMap;
import com.yahoo.component.Version;
import com.yahoo.component.Vtag;
import com.yahoo.config.application.api.ApplicationPackage;
@@ -22,8 +21,8 @@ import com.yahoo.config.provision.Zone;
import com.yahoo.io.reader.NamedReader;
import com.yahoo.searchdefinition.RankProfileRegistry;
import com.yahoo.searchdefinition.SearchBuilder;
+import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
import com.yahoo.searchdefinition.parser.ParseException;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
import com.yahoo.vespa.config.ConfigDefinition;
import com.yahoo.vespa.config.ConfigDefinitionBuilder;
import com.yahoo.vespa.config.ConfigDefinitionKey;
@@ -67,7 +66,7 @@ public class DeployState implements ConfigDefinitionStore {
private final Zone zone;
private final QueryProfiles queryProfiles;
private final SemanticRules semanticRules;
- //private final ImmutableMap<String, ImportedModel> importedMlModels;
+ private final ImportedModels importedModels;
private final ValidationOverrides validationOverrides;
private final Version wantedNodeVespaVersion;
private final Instant now;
@@ -101,6 +100,7 @@ public class DeployState implements ConfigDefinitionStore {
this.zone = zone;
this.queryProfiles = queryProfiles; // TODO: Remove this by seeing how pagetemplates are propagated
this.semanticRules = semanticRules; // TODO: Remove this by seeing how pagetemplates are propagated
+ this.importedModels = new ImportedModels(applicationPackage.getFileReference(ApplicationPackage.MODELS_DIR));
this.validationOverrides = applicationPackage.getValidationOverrides().map(ValidationOverrides::fromXml).orElse(ValidationOverrides.empty);
this.wantedNodeVespaVersion = wantedNodeVespaVersion;
@@ -215,7 +215,7 @@ public class DeployState implements ConfigDefinitionStore {
public SemanticRules getSemanticRules() { return semanticRules; }
/** The (machine learned) models imported from the models/ directory, as an unmodifiable map indexed by model name */
- //public Map<String, ImportedModel> importedMlModels() { return importedMlModels; }
+ public ImportedModels getImportedModels() { return importedModels; }
public Version getWantedNodeVespaVersion() { return wantedNodeVespaVersion; }
diff --git a/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java b/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java
index 0cfde3c655c..7404ae14a5d 100644
--- a/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java
+++ b/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java
@@ -32,6 +32,7 @@ public class MockApplicationPackage implements ApplicationPackage {
public static final String MUSIC_SEARCHDEFINITION = createSearchDefinition("music", "foo");
public static final String BOOK_SEARCHDEFINITION = createSearchDefinition("book", "bar");
+ private final File root;
private final String hostsS;
private final String servicesS;
private final List<String> searchDefinitions;
@@ -42,9 +43,11 @@ public class MockApplicationPackage implements ApplicationPackage {
private final QueryProfileRegistry queryProfileRegistry;
private final ApplicationMetaData applicationMetaData;
- protected MockApplicationPackage(String hosts, String services, List<String> searchDefinitions, String searchDefinitionDir,
+ protected MockApplicationPackage(File root, String hosts, String services, List<String> searchDefinitions,
+ String searchDefinitionDir,
String deploymentSpec, String validationOverrides, boolean failOnValidateXml,
String queryProfile, String queryProfileType) {
+ this.root = root;
this.hostsS = hosts;
this.servicesS = services;
this.searchDefinitions = searchDefinitions;
@@ -57,6 +60,9 @@ public class MockApplicationPackage implements ApplicationPackage {
applicationMetaData = new ApplicationMetaData("user", "dir", 0L, false, "application", "checksum", 0L, 0L);
}
+ /** Returns the root of this application package relative to the current dir */
+ protected File root() { return root; }
+
@Override
public String getApplicationName() {
return "mock application";
@@ -111,6 +117,11 @@ public class MockApplicationPackage implements ApplicationPackage {
}
@Override
+ public File getFileReference(Path path) {
+ return Path.fromString(root.toString()).append(path).toFile();
+ }
+
+ @Override
public String getHostSource() {
return "mock source";
}
@@ -163,6 +174,7 @@ public class MockApplicationPackage implements ApplicationPackage {
public static class Builder {
+ private File root = new File("nonexisting");
private String hosts = null;
private String services = null;
private List<String> searchDefinitions = Collections.emptyList();
@@ -176,6 +188,11 @@ public class MockApplicationPackage implements ApplicationPackage {
public Builder() {
}
+ public Builder withRoot(File root) {
+ this.root = root;
+ return this;
+ }
+
public Builder withEmptyHosts() {
return this.withHosts(emptyHosts);
}
@@ -235,7 +252,7 @@ public class MockApplicationPackage implements ApplicationPackage {
}
public ApplicationPackage build() {
- return new MockApplicationPackage(hosts, services, searchDefinitions, searchDefinitionDir,
+ return new MockApplicationPackage(root, hosts, services, searchDefinitions, searchDefinitionDir,
deploymentSpec, validationOverrides, failOnValidateXml,
queryProfile, queryProfileType);
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
index 9fec1983465..2e66784527d 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
@@ -16,6 +16,7 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.TypeContext;
@@ -682,10 +683,10 @@ public class RankProfile implements Serializable, Cloneable {
* Returns a copy of this where the content is optimized for execution.
* Compiled profiles should never be modified.
*/
- public RankProfile compile(QueryProfileRegistry queryProfiles) {
+ public RankProfile compile(QueryProfileRegistry queryProfiles, ImportedModels importedModels) {
try {
RankProfile compiled = this.clone();
- compiled.compileThis(queryProfiles);
+ compiled.compileThis(queryProfiles, importedModels);
return compiled;
}
catch (IllegalArgumentException e) {
@@ -693,19 +694,19 @@ public class RankProfile implements Serializable, Cloneable {
}
}
- private void compileThis(QueryProfileRegistry queryProfiles) {
+ private void compileThis(QueryProfileRegistry queryProfiles, ImportedModels importedModels) {
parseExpressions();
checkNameCollisions(getMacros(), getConstants());
ExpressionTransforms expressionTransforms = new ExpressionTransforms();
// Macro compiling first pass: compile inline macros without resolving other macros
- Map<String, Macro> inlineMacros = compileMacros(getInlineMacros(), queryProfiles, Collections.emptyMap(), expressionTransforms);
+ Map<String, Macro> inlineMacros = compileMacros(getInlineMacros(), queryProfiles, importedModels, Collections.emptyMap(), expressionTransforms);
// Macro compiling second pass: compile all macros and insert previously compiled inline macros
- macros = compileMacros(getMacros(), queryProfiles, inlineMacros, expressionTransforms);
+ macros = compileMacros(getMacros(), queryProfiles, importedModels, inlineMacros, expressionTransforms);
- firstPhaseRanking = compile(this.getFirstPhaseRanking(), queryProfiles, getConstants(), inlineMacros, expressionTransforms);
- secondPhaseRanking = compile(this.getSecondPhaseRanking(), queryProfiles, getConstants(), inlineMacros, expressionTransforms);
+ firstPhaseRanking = compile(this.getFirstPhaseRanking(), queryProfiles, importedModels, getConstants(), inlineMacros, expressionTransforms);
+ secondPhaseRanking = compile(this.getSecondPhaseRanking(), queryProfiles, importedModels, getConstants(), inlineMacros, expressionTransforms);
}
private void checkNameCollisions(Map<String, Macro> macros, Map<String, Value> constants) {
@@ -723,12 +724,13 @@ public class RankProfile implements Serializable, Cloneable {
private Map<String, Macro> compileMacros(Map<String, Macro> macros,
QueryProfileRegistry queryProfiles,
+ ImportedModels importedModels,
Map<String, Macro> inlineMacros,
ExpressionTransforms expressionTransforms) {
Map<String, Macro> compiledMacros = new LinkedHashMap<>();
for (Map.Entry<String, Macro> entry : macros.entrySet()) {
Macro macro = entry.getValue().clone();
- RankingExpression exp = compile(macro.getRankingExpression(), queryProfiles, getConstants(), inlineMacros, expressionTransforms);
+ RankingExpression exp = compile(macro.getRankingExpression(), queryProfiles, importedModels, getConstants(), inlineMacros, expressionTransforms);
macro.setRankingExpression(exp);
compiledMacros.put(entry.getKey(), macro);
}
@@ -737,6 +739,7 @@ public class RankProfile implements Serializable, Cloneable {
private RankingExpression compile(RankingExpression expression,
QueryProfileRegistry queryProfiles,
+ ImportedModels importedModels,
Map<String, Value> constants,
Map<String, Macro> inlineMacros,
ExpressionTransforms expressionTransforms) {
@@ -745,6 +748,7 @@ public class RankProfile implements Serializable, Cloneable {
RankProfileTransformContext context = new RankProfileTransformContext(this,
queryProfiles,
+ importedModels,
constants,
inlineMacros,
rankPropertiesOutput);
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java
index 985087e905b..4af26b72817 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java
@@ -12,6 +12,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.RankProfileRegistry;
import com.yahoo.searchdefinition.Search;
import com.yahoo.searchdefinition.derived.validation.Validation;
+import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
import java.io.IOException;
import java.io.Writer;
@@ -46,8 +47,11 @@ public class DerivedConfiguration {
* modified.
* @param rankProfileRegistry a {@link com.yahoo.searchdefinition.RankProfileRegistry}
*/
- public DerivedConfiguration(Search search, RankProfileRegistry rankProfileRegistry, QueryProfileRegistry queryProfiles) {
- this(search, null, new BaseDeployLogger(), rankProfileRegistry, queryProfiles);
+ public DerivedConfiguration(Search search,
+ RankProfileRegistry rankProfileRegistry,
+ QueryProfileRegistry queryProfiles,
+ ImportedModels importedModels) {
+ this(search, null, new BaseDeployLogger(), rankProfileRegistry, queryProfiles, importedModels);
}
/**
@@ -63,10 +67,12 @@ public class DerivedConfiguration {
* @param rankProfileRegistry a {@link com.yahoo.searchdefinition.RankProfileRegistry}
* @param queryProfiles the query profiles of this application
*/
- public DerivedConfiguration(Search search, List<Search> abstractSearchList,
+ public DerivedConfiguration(Search search,
+ List<Search> abstractSearchList,
DeployLogger deployLogger,
RankProfileRegistry rankProfileRegistry,
- QueryProfileRegistry queryProfiles) {
+ QueryProfileRegistry queryProfiles,
+ ImportedModels importedModels) {
Validator.ensureNotNull("Search definition", search);
if ( ! search.isProcessed()) {
throw new IllegalArgumentException("Search '" + search.getName() + "' not processed.");
@@ -88,7 +94,7 @@ public class DerivedConfiguration {
summaries = new Summaries(search, deployLogger);
summaryMap = new SummaryMap(search, summaries);
juniperrc = new Juniperrc(search);
- rankProfileList = new RankProfileList(search, attributeFields, rankProfileRegistry, queryProfiles);
+ rankProfileList = new RankProfileList(search, attributeFields, rankProfileRegistry, queryProfiles, importedModels);
indexingScript = new IndexingScript(search);
indexInfo = new IndexInfo(search);
indexSchema = new IndexSchema(search);
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
index 77645331d9e..1e978e43d6a 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
@@ -3,6 +3,7 @@ package com.yahoo.searchdefinition.derived;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.RankProfileRegistry;
+import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.Search;
@@ -26,24 +27,27 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
public RankProfileList(Search search,
AttributeFields attributeFields,
RankProfileRegistry rankProfileRegistry,
- QueryProfileRegistry queryProfiles) {
+ QueryProfileRegistry queryProfiles,
+ ImportedModels importedModels) {
setName(search.getName());
- deriveRankProfiles(rankProfileRegistry, queryProfiles, search, attributeFields);
+ deriveRankProfiles(rankProfileRegistry, queryProfiles, importedModels, search, attributeFields);
}
private void deriveRankProfiles(RankProfileRegistry rankProfileRegistry,
QueryProfileRegistry queryProfiles,
+ ImportedModels importedModels,
Search search,
AttributeFields attributeFields) {
RawRankProfile defaultProfile = new RawRankProfile(rankProfileRegistry.getRankProfile(search, "default"),
queryProfiles,
+ importedModels,
attributeFields);
rankProfiles.put(defaultProfile.getName(), defaultProfile);
for (RankProfile rank : rankProfileRegistry.localRankProfiles(search)) {
if ("default".equals(rank.getName())) continue;
- RawRankProfile rawRank = new RawRankProfile(rank, queryProfiles, attributeFields);
+ RawRankProfile rawRank = new RawRankProfile(rank, queryProfiles, importedModels, attributeFields);
rankProfiles.put(rawRank.getName(), rawRank);
}
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
index b11d5962713..0d104a97698 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
@@ -9,6 +9,7 @@ import com.yahoo.searchdefinition.document.RankType;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
@@ -47,9 +48,9 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
/**
* Creates a raw rank profile from the given rank profile
*/
- public RawRankProfile(RankProfile rankProfile, QueryProfileRegistry queryProfiles, AttributeFields attributeFields) {
+ public RawRankProfile(RankProfile rankProfile, QueryProfileRegistry queryProfiles, ImportedModels importedModels, AttributeFields attributeFields) {
this.name = rankProfile.getName();
- compressedProperties = compress(removePartFromKeys(new Deriver(rankProfile, queryProfiles, attributeFields).derive()));
+ compressedProperties = compress(removePartFromKeys(new Deriver(rankProfile, queryProfiles, importedModels, attributeFields).derive()));
}
private List<Pair<String, String>> removePartFromKeys(Map<String, String> map) {
@@ -155,8 +156,8 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
/**
* Creates a raw rank profile from the given rank profile
*/
- public Deriver(RankProfile rankProfile, QueryProfileRegistry queryProfiles, AttributeFields attributeFields) {
- RankProfile compiled = rankProfile.compile(queryProfiles);
+ public Deriver(RankProfile rankProfile, QueryProfileRegistry queryProfiles, ImportedModels importedModels, AttributeFields attributeFields) {
+ RankProfile compiled = rankProfile.compile(queryProfiles, importedModels);
attributeTypes = compiled.getAttributeTypes();
queryFeatureTypes = compiled.getQueryFeatureTypes();
deriveRankingFeatures(compiled);
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
index 21e51b68677..a38fbe1aaa0 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
@@ -16,6 +16,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
+import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
@@ -77,14 +78,12 @@ public class ConvertedModel {
* Create a converted model for a rank profile given from either an imported model,
* or (if unavailable) from stored application package data.
*/
- public ConvertedModel(Path modelPath,
- RankProfileTransformContext context,
- ImportedModels importedModels) {
+ public ConvertedModel(Path modelPath, RankProfileTransformContext context) {
this.modelPath = modelPath;
this.modelName = toModelName(modelPath);
ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), modelPath);
if ( store.hasSourceModel())
- expressions = convertModel(store, context.rankProfile(), context.queryProfiles(), importedModels);
+ expressions = convertModel(store, context.rankProfile(), context.queryProfiles(), context.importedModels());
else
expressions = transformFromStoredModel(store, context.rankProfile());
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ImportedModels.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ImportedModels.java
deleted file mode 100644
index 7cf0a5d8b76..00000000000
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ImportedModels.java
+++ /dev/null
@@ -1,49 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchdefinition.expressiontransforms;
-
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
-import com.yahoo.path.Path;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ModelImporter;
-import com.yahoo.searchlib.rankingexpression.integration.ml.OnnxImporter;
-import com.yahoo.searchlib.rankingexpression.integration.ml.TensorFlowImporter;
-
-import java.io.File;
-import java.util.HashMap;
-import java.util.Map;
-
-/**
- * All models imported from the models/ directory in the application package
- *
- * @author bratseth
- */
-class ImportedModels {
-
- /** The cache of already imported models */
- private final Map<String, ImportedModel> importedModels = new HashMap<>();
-
- private final ImmutableList<ModelImporter> importers = ImmutableList.of(new TensorFlowImporter(), new OnnxImporter());
-
- ImportedModels() {
- }
-
- /**
- * Returns the model at the given location in the application package (lazily loaded),
- *
- * @param modelPath the full path to this model (file or directory, depending on model type)
- * under the application package
- * @throws IllegalArgumentException if the model cannot be loaded
- */
- public ImportedModel get(File modelPath) {
- String modelName = toName(modelPath);
- ModelImporter importer = importers.stream().filter(item -> item.canImport(modelPath.toString())).findFirst().get();
- return importedModels.computeIfAbsent(modelName, __ -> importer.importModel(modelName, modelPath));
- }
-
- private static String toName(File modelPath) {
- Path localPath = Path.fromString(modelPath.toString()).getChildPath();
- return localPath.toString().replace("/", "_").replace('.', '_');
- }
-
-}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
index 0b68a67acff..36dc200f3c9 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
@@ -3,13 +3,13 @@
package com.yahoo.searchdefinition.expressiontransforms;
import com.yahoo.path.Path;
-import com.yahoo.searchlib.rankingexpression.integration.ml.OnnxImporter;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
+import java.io.File;
import java.io.UncheckedIOException;
import java.util.HashMap;
import java.util.Map;
@@ -24,8 +24,6 @@ import java.util.Map;
*/
public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> {
- private final ImportedModels importedOnnxModels = new ImportedModels();
-
/** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */
private final Map<Path, ConvertedModel> convertedOnnxModels = new HashMap<>();
@@ -45,7 +43,7 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans
try {
Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0)));
ConvertedModel convertedModel =
- convertedOnnxModels.computeIfAbsent(modelPath, __ -> new ConvertedModel(modelPath, context, importedOnnxModels));
+ convertedOnnxModels.computeIfAbsent(modelPath, __ -> new ConvertedModel(modelPath, context));
return convertedModel.expression(asFeatureArguments(feature.getArguments()));
}
catch (IllegalArgumentException | UncheckedIOException e) {
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java
index 5da5b3dabda..c7b4e85d74e 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java
@@ -4,6 +4,7 @@ package com.yahoo.searchdefinition.expressiontransforms;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
import com.yahoo.searchlib.rankingexpression.transform.TransformContext;
import java.util.Map;
@@ -17,23 +18,27 @@ public class RankProfileTransformContext extends TransformContext {
private final RankProfile rankProfile;
private final QueryProfileRegistry queryProfiles;
+ private final ImportedModels importedModels;
private final Map<String, RankProfile.Macro> inlineMacros;
private final Map<String, String> rankPropertiesOutput;
public RankProfileTransformContext(RankProfile rankProfile,
QueryProfileRegistry queryProfiles,
+ ImportedModels importedModels,
Map<String, Value> constants,
Map<String, RankProfile.Macro> inlineMacros,
Map<String, String> rankPropertiesOutput) {
super(constants);
this.rankProfile = rankProfile;
this.queryProfiles = queryProfiles;
+ this.importedModels = importedModels;
this.inlineMacros = inlineMacros;
this.rankPropertiesOutput = rankPropertiesOutput;
}
public RankProfile rankProfile() { return rankProfile; }
public QueryProfileRegistry queryProfiles() { return queryProfiles; }
+ public ImportedModels importedModels() { return importedModels; }
public Map<String, RankProfile.Macro> inlineMacros() { return inlineMacros; }
public Map<String, String> rankPropertiesOutput() { return rankPropertiesOutput; }
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index 4f15fb5a291..619c13da764 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -2,13 +2,13 @@
package com.yahoo.searchdefinition.expressiontransforms;
import com.yahoo.path.Path;
-import com.yahoo.searchlib.rankingexpression.integration.ml.TensorFlowImporter;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
+import java.io.File;
import java.io.UncheckedIOException;
import java.util.HashMap;
import java.util.Map;
@@ -22,8 +22,6 @@ import java.util.Map;
*/
public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> {
- private final ImportedModels importedTensorFlowModels = new ImportedModels();
-
/** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */
private final Map<Path, ConvertedModel> convertedTensorFlowModels = new HashMap<>();
@@ -43,7 +41,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
try {
Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0)));
ConvertedModel convertedModel =
- convertedTensorFlowModels.computeIfAbsent(modelPath, __ -> new ConvertedModel(modelPath, context, importedTensorFlowModels));
+ convertedTensorFlowModels.computeIfAbsent(modelPath, __ -> new ConvertedModel(modelPath, context));
return convertedModel.expression(asFeatureArguments(feature.getArguments()));
}
catch (IllegalArgumentException | UncheckedIOException e) {
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java
index 6e0aa562508..c7762f09851 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java
@@ -289,9 +289,12 @@ public class IndexedSearchCluster extends SearchCluster
if ( ! (search instanceof UnproperSearch)) {
DocumentDatabase db = new DocumentDatabase(this,
search.getName(),
- new DerivedConfiguration(search, globalSearches, deployLogger(),
+ new DerivedConfiguration(search,
+ globalSearches,
+ deployLogger(),
getRoot().getDeployState().rankProfileRegistry(),
- getRoot().getDeployState().getQueryProfiles().getRegistry()));
+ getRoot().getDeployState().getQueryProfiles().getRegistry(),
+ getRoot().getDeployState().getImportedModels()));
// TODO: remove explicit adding of user configs when the complete content model is built using builders.
db.mergeUserConfigs(spec.getUserConfigs());
documentDbs.add(db);
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/StreamingSearchCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/search/StreamingSearchCluster.java
index e87df3d530b..66553ebadf6 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/search/StreamingSearchCluster.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/search/StreamingSearchCluster.java
@@ -104,7 +104,8 @@ public class StreamingSearchCluster extends SearchCluster implements
}
this.sdConfig = new DerivedConfiguration(localSearch, globalSearches, deployLogger(),
getRoot().getDeployState().rankProfileRegistry(),
- getRoot().getDeployState().getQueryProfiles().getRegistry());
+ getRoot().getDeployState().getQueryProfiles().getRegistry(),
+ getRoot().getDeployState().getImportedModels());
}
@Override
public DerivedConfiguration getSdConfig() {
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java
index bff34411d44..03fa92f5cb9 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java
@@ -4,6 +4,7 @@ package com.yahoo.searchdefinition;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.derived.DerivedConfiguration;
import com.yahoo.searchdefinition.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
import org.junit.Test;
import java.io.IOException;
@@ -23,7 +24,7 @@ public class IncorrectRankingExpressionFileRefTestCase extends SearchDefinitionT
Search search = SearchBuilder.buildFromFile("src/test/examples/incorrectrankingexpressionfileref.sd",
registry,
new QueryProfileRegistry());
- new DerivedConfiguration(search, registry, new QueryProfileRegistry()); // cause rank profile parsing
+ new DerivedConfiguration(search, registry, new QueryProfileRegistry(), new ImportedModels()); // cause rank profile parsing
fail("parsing should have failed");
} catch (IllegalArgumentException e) {
e.printStackTrace();
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java
index ed612039fb7..de9df08f5c0 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java
@@ -16,6 +16,7 @@ import com.yahoo.searchdefinition.document.RankType;
import com.yahoo.searchdefinition.document.SDDocumentType;
import com.yahoo.searchdefinition.document.SDField;
import com.yahoo.searchdefinition.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
import com.yahoo.vespa.model.container.search.QueryProfiles;
import org.junit.Test;
@@ -91,7 +92,7 @@ public class RankProfileTestCase extends SearchDefinitionTestCase {
assertEquals(8, rankProfile.getNumThreadsPerSearch());
assertEquals(70, rankProfile.getMinHitsPerThread());
assertEquals(1200, rankProfile.getNumSearchPartitions());
- RawRankProfile rawRankProfile = new RawRankProfile(rankProfile, new QueryProfileRegistry(), attributeFields);
+ RawRankProfile rawRankProfile = new RawRankProfile(rankProfile, new QueryProfileRegistry(), new ImportedModels(), attributeFields);
assertTrue(findProperty(rawRankProfile.configProperties(), "vespa.matching.termwise_limit").isPresent());
assertEquals("0.78", findProperty(rawRankProfile.configProperties(), "vespa.matching.termwise_limit").get());
assertTrue(findProperty(rawRankProfile.configProperties(), "vespa.matching.numthreadspersearch").isPresent());
@@ -126,7 +127,7 @@ public class RankProfileTestCase extends SearchDefinitionTestCase {
}
private static void assertAttributeTypeSettings(RankProfile profile, Search search) {
- RawRankProfile rawProfile = new RawRankProfile(profile, new QueryProfileRegistry(), new AttributeFields(search));
+ RawRankProfile rawProfile = new RawRankProfile(profile, new QueryProfileRegistry(), new ImportedModels(), new AttributeFields(search));
assertEquals("tensor(x[10])", findProperty(rawProfile.configProperties(), "vespa.type.attribute.a").get());
assertEquals("tensor(y{})", findProperty(rawProfile.configProperties(), "vespa.type.attribute.b").get());
assertEquals("tensor(x[])", findProperty(rawProfile.configProperties(), "vespa.type.attribute.c").get());
@@ -168,7 +169,7 @@ public class RankProfileTestCase extends SearchDefinitionTestCase {
}
private static void assertQueryFeatureTypeSettings(RankProfile profile, Search search) {
- RawRankProfile rawProfile = new RawRankProfile(profile, new QueryProfileRegistry(), new AttributeFields(search));
+ RawRankProfile rawProfile = new RawRankProfile(profile, new QueryProfileRegistry(), new ImportedModels(), new AttributeFields(search));
assertEquals("tensor(x[10])", findProperty(rawProfile.configProperties(), "vespa.type.query.tensor1").get());
assertEquals("tensor(y{})", findProperty(rawProfile.configProperties(), "vespa.type.query.tensor2").get());
assertFalse(findProperty(rawProfile.configProperties(), "vespa.type.query.tensor3").isPresent());
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java
index 15ddef60807..3a2482b56d0 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java
@@ -7,6 +7,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RawRankProfile;
import com.yahoo.searchdefinition.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
import org.junit.Test;
import java.util.ArrayList;
@@ -60,7 +61,7 @@ public class RankPropertiesTestCase extends SearchDefinitionTestCase {
assertEquals("query(a) = 1500", parent.getRankProperties().get(0).toString());
// Check derived model
- RawRankProfile rawParent = new RawRankProfile(parent, new QueryProfileRegistry(), attributeFields);
+ RawRankProfile rawParent = new RawRankProfile(parent, new QueryProfileRegistry(), new ImportedModels(), attributeFields);
assertEquals("(query(a),1500)", rawParent.configProperties().get(0).toString());
}
@@ -72,6 +73,7 @@ public class RankPropertiesTestCase extends SearchDefinitionTestCase {
// Check derived model
RawRankProfile rawChild = new RawRankProfile(rankProfileRegistry.getRankProfile(search, "child"),
new QueryProfileRegistry(),
+ new ImportedModels(),
attributeFields);
assertEquals("(query(a),2000)", rawChild.configProperties().get(0).toString());
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java
index 82b9f5ac043..da546967dc1 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java
@@ -4,6 +4,7 @@ package com.yahoo.searchdefinition;
import com.yahoo.collections.Pair;
import com.yahoo.config.model.application.provider.BaseDeployLogger;
import com.yahoo.search.query.profile.QueryProfileRegistry;
+import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
import com.yahoo.yolean.Exceptions;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RawRankProfile;
@@ -69,18 +70,19 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase
"}\n");
builder.build();
Search s = builder.getSearch();
- RankProfile parent = rankProfileRegistry.getRankProfile(s, "parent").compile(queryProfileRegistry);
+ RankProfile parent = rankProfileRegistry.getRankProfile(s, "parent").compile(queryProfileRegistry, new ImportedModels());
assertEquals("0.0", parent.getFirstPhaseRanking().getRoot().toString());
- RankProfile child1 = rankProfileRegistry.getRankProfile(s, "child1").compile(queryProfileRegistry);
+ RankProfile child1 = rankProfileRegistry.getRankProfile(s, "child1").compile(queryProfileRegistry, new ImportedModels());
assertEquals("6.5", child1.getFirstPhaseRanking().getRoot().toString());
assertEquals("11.5", child1.getSecondPhaseRanking().getRoot().toString());
- RankProfile child2 = rankProfileRegistry.getRankProfile(s, "child2").compile(queryProfileRegistry);
+ RankProfile child2 = rankProfileRegistry.getRankProfile(s, "child2").compile(queryProfileRegistry, new ImportedModels());
assertEquals("16.6", child2.getFirstPhaseRanking().getRoot().toString());
assertEquals("foo: 14.0", child2.getMacros().get("foo").getRankingExpression().toString());
List<Pair<String, String>> rankProperties = new RawRankProfile(child2,
queryProfileRegistry,
+ new ImportedModels(),
new AttributeFields(s)).configProperties();
assertEquals("(rankingExpression(foo).rankingScript,14.0)", rankProperties.get(0).toString());
assertEquals("(rankingExpression(firstphase).rankingScript,16.6)", rankProperties.get(2).toString());
@@ -111,7 +113,7 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase
builder.build();
Search s = builder.getSearch();
try {
- rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry());
+ rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry(), new ImportedModels());
fail("Should have caused an exception");
}
catch (IllegalArgumentException e) {
@@ -174,7 +176,7 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase
profile.parseExpressions(); // TODO: Do differently
assertEquals("safeLog(popShareSlowDecaySignal,myValue)", profile.getMacros().get("POP_SLOW_SCORE").getRankingExpression().getRoot().toString());
assertEquals("safeLog(popShareSlowDecaySignal,-9.21034037)",
- profile.compile(new QueryProfileRegistry()).getMacros().get("POP_SLOW_SCORE").getRankingExpression().getRoot().toString());
+ profile.compile(new QueryProfileRegistry(), new ImportedModels()).getMacros().get("POP_SLOW_SCORE").getRankingExpression().getRoot().toString());
}
@Test
@@ -197,7 +199,7 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase
Search s = builder.getSearch();
RankProfile profile = rankProfileRegistry.getRankProfile(s, "test");
assertEquals("k1 + (k2 + k3) / 100000000.0",
- profile.compile(new QueryProfileRegistry()).getMacros().get("rank_default").getRankingExpression().getRoot().toString());
+ profile.compile(new QueryProfileRegistry(), new ImportedModels()).getMacros().get("rank_default").getRankingExpression().getRoot().toString());
}
@Test
@@ -223,7 +225,7 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase
Search s = builder.getSearch();
RankProfile profile = rankProfileRegistry.getRankProfile(s, "test");
assertEquals("0.5 + 50 * (attribute(rating_yelp) - 3)",
- profile.compile(new QueryProfileRegistry()).getMacros().get("rank_default").getRankingExpression().getRoot().toString());
+ profile.compile(new QueryProfileRegistry(), new ImportedModels()).getMacros().get("rank_default").getRankingExpression().getRoot().toString());
}
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java
index a0dd18aeea9..555aa698c65 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java
@@ -6,6 +6,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RawRankProfile;
import com.yahoo.searchdefinition.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
import org.junit.Test;
import java.util.Optional;
@@ -61,10 +62,10 @@ public class RankingExpressionInliningTestCase extends SearchDefinitionTestCase
builder.build();
Search s = builder.getSearch();
- RankProfile parent = rankProfileRegistry.getRankProfile(s, "parent").compile(new QueryProfileRegistry());
+ RankProfile parent = rankProfileRegistry.getRankProfile(s, "parent").compile(new QueryProfileRegistry(), new ImportedModels());
assertEquals("7.0 * (3 + attribute(a) + attribute(b) * (attribute(a) * 3 + if (7.0 < attribute(a), 1, 2) == 0))",
parent.getFirstPhaseRanking().getRoot().toString());
- RankProfile child = rankProfileRegistry.getRankProfile(s, "child").compile(new QueryProfileRegistry());
+ RankProfile child = rankProfileRegistry.getRankProfile(s, "child").compile(new QueryProfileRegistry(), new ImportedModels());
assertEquals("7.0 * (9 + attribute(a))",
child.getFirstPhaseRanking().getRoot().toString());
}
@@ -121,14 +122,14 @@ public class RankingExpressionInliningTestCase extends SearchDefinitionTestCase
builder.build();
Search s = builder.getSearch();
- RankProfile parent = rankProfileRegistry.getRankProfile(s, "parent").compile(new QueryProfileRegistry());
+ RankProfile parent = rankProfileRegistry.getRankProfile(s, "parent").compile(new QueryProfileRegistry(), new ImportedModels());
assertEquals("17.0", parent.getFirstPhaseRanking().getRoot().toString());
assertEquals("0.0", parent.getSecondPhaseRanking().getRoot().toString());
assertEquals("10.0", getRankingExpression("foo", parent, s));
assertEquals("17.0", getRankingExpression("firstphase", parent, s));
assertEquals("0.0", getRankingExpression("secondphase", parent, s));
- RankProfile child = rankProfileRegistry.getRankProfile(s, "child").compile(new QueryProfileRegistry());
+ RankProfile child = rankProfileRegistry.getRankProfile(s, "child").compile(new QueryProfileRegistry(), new ImportedModels());
assertEquals("31.0 + bar + arg(4.0)", child.getFirstPhaseRanking().getRoot().toString());
assertEquals("24.0", child.getSecondPhaseRanking().getRoot().toString());
assertEquals("12.0", getRankingExpression("foo", child, s));
@@ -177,7 +178,7 @@ public class RankingExpressionInliningTestCase extends SearchDefinitionTestCase
builder.build();
Search s = builder.getSearch();
- RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry());
+ RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry(), new ImportedModels());
assertEquals("attribute(a) + C + (attribute(b) + 1)", test.getFirstPhaseRanking().getRoot().toString());
assertEquals("attribute(a) + attribute(b)", getRankingExpression("C", test, s));
assertEquals("attribute(b) + 1", getRankingExpression("D", test, s));
@@ -208,7 +209,7 @@ public class RankingExpressionInliningTestCase extends SearchDefinitionTestCase
private String getRankingExpression(String name, RankProfile rankProfile, Search search) {
Optional<String> rankExpression =
- new RawRankProfile(rankProfile, new QueryProfileRegistry(), new AttributeFields(search))
+ new RawRankProfile(rankProfile, new QueryProfileRegistry(), new ImportedModels(), new AttributeFields(search))
.configProperties()
.stream()
.filter(r -> r.getFirst().equals("rankingExpression(" + name + ").rankingScript"))
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java
index ed1b00e2875..2bd7d3031a5 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java
@@ -9,6 +9,7 @@ import com.yahoo.search.query.profile.types.QueryProfileType;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RawRankProfile;
import com.yahoo.searchdefinition.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
import org.junit.Ignore;
import org.junit.Test;
@@ -44,9 +45,10 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase
"}\n");
builder.build();
Search s = builder.getSearch();
- RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry());
+ RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry(), new ImportedModels());
List<Pair<String, String>> testRankProperties = new RawRankProfile(test,
new QueryProfileRegistry(),
+ new ImportedModels(),
new AttributeFields(s)).configProperties();
assertEquals("(rankingExpression(sin).rankingScript,x * x)",
testRankProperties.get(0).toString());
@@ -87,9 +89,10 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase
"}\n");
builder.build();
Search s = builder.getSearch();
- RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry());
+ RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry(), new ImportedModels());
List<Pair<String, String>> testRankProperties = new RawRankProfile(test,
new QueryProfileRegistry(),
+ new ImportedModels(),
new AttributeFields(s)).configProperties();
assertEquals("(rankingExpression(tan).rankingScript,x * x)",
testRankProperties.get(0).toString());
@@ -136,9 +139,10 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase
"}\n");
builder.build();
Search s = builder.getSearch();
- RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry());
+ RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry(), new ImportedModels());
List<Pair<String, String>> testRankProperties = new RawRankProfile(test,
new QueryProfileRegistry(),
+ new ImportedModels(),
new AttributeFields(s)).configProperties();
assertEquals("(rankingExpression(sin).rankingScript,x * x)",
testRankProperties.get(0).toString());
@@ -199,9 +203,10 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase
"}\n");
builder.build();
Search s = builder.getSearch();
- RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(queryProfiles);
+ RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(queryProfiles, new ImportedModels());
List<Pair<String, String>> testRankProperties = new RawRankProfile(test,
queryProfiles,
+ new ImportedModels(),
new AttributeFields(s)).configProperties();
assertEquals("(rankingExpression(relu).rankingScript,max(1.0,x))",
testRankProperties.get(0).toString());
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java
index a07fea69592..3fe3a7c3de1 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java
@@ -4,6 +4,7 @@ package com.yahoo.searchdefinition;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.derived.DerivedConfiguration;
import com.yahoo.searchdefinition.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
import org.junit.Ignore;
import org.junit.Test;
@@ -24,7 +25,7 @@ public class RankingExpressionValidationTestCase extends SearchDefinitionTestCas
try {
RankProfileRegistry registry = new RankProfileRegistry();
Search search = importWithExpression(expression, registry);
- new DerivedConfiguration(search, registry, new QueryProfileRegistry()); // cause rank profile parsing
+ new DerivedConfiguration(search, registry, new QueryProfileRegistry(), new ImportedModels()); // cause rank profile parsing
fail("No exception on incorrect ranking expression " + expression);
} catch (IllegalArgumentException e) {
// Success
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java
index 3d2bce62713..88a02cc7a93 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java
@@ -8,6 +8,7 @@ import com.yahoo.searchdefinition.Search;
import com.yahoo.searchdefinition.SearchBuilder;
import com.yahoo.searchdefinition.SearchDefinitionTestCase;
import com.yahoo.searchdefinition.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
import com.yahoo.vespa.configmodel.producers.DocumentManager;
import com.yahoo.vespa.configmodel.producers.DocumentTypes;
@@ -80,14 +81,16 @@ public abstract class AbstractExportingTestCase extends SearchDefinitionTestCase
protected DerivedConfiguration derive(String dirName, String searchDefinitionName, SearchBuilder builder) throws IOException {
DerivedConfiguration config = new DerivedConfiguration(builder.getSearch(searchDefinitionName),
builder.getRankProfileRegistry(),
- builder.getQueryProfileRegistry());
+ builder.getQueryProfileRegistry(),
+ new ImportedModels());
return export(dirName, builder, config);
}
protected DerivedConfiguration derive(String dirName, SearchBuilder builder, Search search) throws IOException {
DerivedConfiguration config = new DerivedConfiguration(search,
builder.getRankProfileRegistry(),
- builder.getQueryProfileRegistry());
+ builder.getQueryProfileRegistry(),
+ new ImportedModels());
return export(dirName, builder, config);
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java
index 21467776ad9..f4344c9b03c 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java
@@ -10,6 +10,7 @@ import com.yahoo.searchdefinition.SearchDefinitionTestCase;
import com.yahoo.searchdefinition.document.SDDocumentType;
import com.yahoo.searchdefinition.document.SDField;
import com.yahoo.searchdefinition.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
import org.junit.Test;
import java.io.IOException;
@@ -34,7 +35,7 @@ public class EmptyRankProfileTestCase extends SearchDefinitionTestCase {
doc.addField(new SDField("c", DataType.STRING));
search = SearchBuilder.buildFromRawSearch(search, rankProfileRegistry, new QueryProfileRegistry());
- new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry());
+ new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(), new ImportedModels());
}
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java
index 59d2d9879d1..dec4b734f27 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java
@@ -11,6 +11,7 @@ import com.yahoo.searchdefinition.SearchBuilder;
import com.yahoo.searchdefinition.document.SDDocumentType;
import com.yahoo.searchdefinition.document.SDField;
import com.yahoo.searchdefinition.processing.Processing;
+import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
import com.yahoo.vespa.model.container.search.QueryProfiles;
import org.junit.Test;
@@ -41,7 +42,7 @@ public class LiteralBoostTestCase extends AbstractExportingTestCase {
other.addRankSetting(new RankProfile.RankSetting("a", RankProfile.RankSetting.Type.LITERALBOOST, 333));
Processing.process(search, new BaseDeployLogger(), rankProfileRegistry, new QueryProfiles(), true);
- DerivedConfiguration derived=new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry());
+ DerivedConfiguration derived=new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(), new ImportedModels());
// Check attribute fields
derived.getAttributeFields(); // TODO: assert content
@@ -72,7 +73,7 @@ public class LiteralBoostTestCase extends AbstractExportingTestCase {
other.addRankSetting(new RankProfile.RankSetting("a", RankProfile.RankSetting.Type.LITERALBOOST, 333));
search = SearchBuilder.buildFromRawSearch(search, rankProfileRegistry, new QueryProfileRegistry());
- DerivedConfiguration derived = new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry());
+ DerivedConfiguration derived = new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(),new ImportedModels());
// Check il script addition
assertIndexing(Arrays.asList("clear_state | guard { input a | tokenize normalize stem:\"SHORTEST\" | index a; }",
@@ -99,7 +100,7 @@ public class LiteralBoostTestCase extends AbstractExportingTestCase {
field2.setLiteralBoost(20);
search = SearchBuilder.buildFromRawSearch(search, rankProfileRegistry, new QueryProfileRegistry());
- new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry());
+ new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(), new ImportedModels());
assertIndexing(Arrays.asList("clear_state | guard { input title | tokenize normalize stem:\"SHORTEST\" | summary title | index title; }",
"clear_state | guard { input body | tokenize normalize stem:\"SHORTEST\" | summary body | index body; }",
"clear_state | guard { input title | tokenize | index title_literal; }",
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java
index f4edc1dd0ae..723cd58a34a 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java
@@ -5,6 +5,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.Search;
import com.yahoo.searchdefinition.SearchBuilder;
import com.yahoo.searchdefinition.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
import org.junit.Test;
import java.io.File;
@@ -34,7 +35,8 @@ public class SimpleInheritTestCase extends AbstractExportingTestCase {
DerivedConfiguration config = new DerivedConfiguration(search,
builder.getRankProfileRegistry(),
- new QueryProfileRegistry());
+ new QueryProfileRegistry(),
+ new ImportedModels());
config.export(toDirName);
checkDir(toDirName, expectedResultsDirName);
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java
index fd029c1df05..26b100a2d96 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java
@@ -10,6 +10,7 @@ import com.yahoo.searchdefinition.SearchDefinitionTestCase;
import com.yahoo.searchdefinition.document.SDDocumentType;
import com.yahoo.searchdefinition.document.SDField;
import com.yahoo.searchdefinition.processing.Processing;
+import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
import com.yahoo.vespa.model.container.search.QueryProfiles;
import org.junit.Test;
@@ -33,7 +34,7 @@ public class TypeConversionTestCase extends SearchDefinitionTestCase {
document.addField(a);
Processing.process(search, new BaseDeployLogger(), rankProfileRegistry, new QueryProfiles(), true);
- DerivedConfiguration derived = new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry());
+ DerivedConfiguration derived = new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(), new ImportedModels());
IndexInfo indexInfo = derived.getIndexInfo();
assertFalse(indexInfo.hasCommand("default", "compact-to-term"));
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java
index d743f60201e..d38bce04617 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java
@@ -9,6 +9,7 @@ import com.yahoo.searchdefinition.derived.DerivedConfiguration;
import com.yahoo.searchdefinition.derived.Deriver;
import com.yahoo.searchdefinition.document.SDDocumentType;
import com.yahoo.searchdefinition.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
import org.junit.Ignore;
import org.junit.Test;
import java.io.File;
@@ -99,7 +100,7 @@ public class ImplicitSearchFieldsTestCase extends SearchDefinitionTestCase {
sb.importFile("src/test/examples/nextgen/simple.sd");
sb.build();
assertNotNull(sb.getSearch());
- new DerivedConfiguration(sb.getSearch(), sb.getRankProfileRegistry(), new QueryProfileRegistry());
+ new DerivedConfiguration(sb.getSearch(), sb.getRankProfileRegistry(), new QueryProfileRegistry(), new ImportedModels());
}
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
index ab689b88993..45cdbfa9c1f 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
@@ -3,12 +3,14 @@ package com.yahoo.searchdefinition.processing;
import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.config.model.test.MockApplicationPackage;
+import com.yahoo.path.Path;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.RankProfileRegistry;
import com.yahoo.searchdefinition.Search;
import com.yahoo.searchdefinition.SearchBuilder;
import com.yahoo.searchdefinition.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
import java.util.HashMap;
import java.util.List;
@@ -76,7 +78,11 @@ class RankProfileSearchFixture {
}
public RankProfile compileRankProfile(String rankProfile) {
- RankProfile compiled = rankProfileRegistry.getRankProfile(search, rankProfile).compile(queryProfileRegistry);
+ return compileRankProfile(rankProfile, Path.fromString("nonexistinng"));
+ }
+
+ public RankProfile compileRankProfile(String rankProfile, Path applicationDir) {
+ RankProfile compiled = rankProfileRegistry.getRankProfile(search, rankProfile).compile(queryProfileRegistry, new ImportedModels(applicationDir.toFile()));
compiledRankProfiles.put(rankProfile, compiled);
return compiled;
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
index a311a2ed706..90137ddde49 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
@@ -127,7 +127,7 @@ public class RankingExpressionWithOnnxTestCase {
" expression: onnx('mnist_softmax.onnx')" +
" }\n" +
" }");
- search.compileRankProfile("my_profile");
+ search.compileRankProfile("my_profile", applicationDir.append("models"));
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
fail("Expecting exception");
}
@@ -226,7 +226,7 @@ public class RankingExpressionWithOnnxTestCase {
String vespaExpressionWithoutConstant =
"join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), " + name + "_Variable, f(a,b)(a * b)), sum, d2), constant(" + name + "_Variable_1), f(a,b)(a + b))";
RankProfileSearchFixture search = uncompiledFixtureWith(rankProfile, new StoringApplicationPackage(applicationDir));
- search.compileRankProfile("my_profile");
+ search.compileRankProfile("my_profile", applicationDir.append("models"));
search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile");
assertNull("Constant overridden by macro is not added",
@@ -241,7 +241,7 @@ public class RankingExpressionWithOnnxTestCase {
storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory);
RankProfileSearchFixture searchFromStored = uncompiledFixtureWith(rankProfile, storedApplication);
- searchFromStored.compileRankProfile("my_profile");
+ searchFromStored.compileRankProfile("my_profile", applicationDir.append("models"));
searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile");
assertNull("Constant overridden by macro is not added",
searchFromStored.search().getRankingConstants().get( name + "_Variable"));
@@ -317,7 +317,7 @@ public class RankingExpressionWithOnnxTestCase {
" }",
constant,
field);
- fixture.compileRankProfile("my_profile");
+ fixture.compileRankProfile("my_profile", applicationDir.append("models"));
return fixture;
}
catch (ParseException e) {
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index a3a286350b5..2804b92767a 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -163,7 +163,7 @@ public class RankingExpressionWithTensorFlowTestCase {
" expression: tensorflow('mnist_softmax/saved')" +
" }\n" +
" }");
- search.compileRankProfile("my_profile");
+ search.compileRankProfile("my_profile", applicationDir.append("models"));
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
fail("Expecting exception");
}
@@ -280,8 +280,8 @@ public class RankingExpressionWithTensorFlowTestCase {
String vespaExpressionWithoutConstant =
"join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), " + name + "_layer_Variable_read, f(a,b)(a * b)), sum, d2), constant(" + name + "_layer_Variable_1_read), f(a,b)(a + b))";
RankProfileSearchFixture search = fixtureWithUncompiled(rankProfiles, new StoringApplicationPackage(applicationDir));
- search.compileRankProfile("my_profile");
- search.compileRankProfile("my_profile_child");
+ search.compileRankProfile("my_profile", applicationDir.append("models"));
+ search.compileRankProfile("my_profile_child", applicationDir.append("models"));
search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile");
search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile_child");
@@ -297,8 +297,8 @@ public class RankingExpressionWithTensorFlowTestCase {
storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory);
RankProfileSearchFixture searchFromStored = fixtureWithUncompiled(rankProfiles, storedApplication);
- searchFromStored.compileRankProfile("my_profile");
- searchFromStored.compileRankProfile("my_profile_child");
+ searchFromStored.compileRankProfile("my_profile", applicationDir.append("models"));
+ searchFromStored.compileRankProfile("my_profile_child", applicationDir.append("models"));
searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile");
searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile_child");
assertNull("Constant overridden by macro is not added",
@@ -358,8 +358,8 @@ public class RankingExpressionWithTensorFlowTestCase {
final String macroExpression2 = "join(reduce(join(join(join(imported_ml_macro_" + name + "_dnn_hidden1_add, 0.009999999776482582, f(a,b)(a * b)), imported_ml_macro_" + name + "_dnn_hidden1_add, f(a,b)(max(a,b))), constant(" + name + "_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(" + name + "_dnn_hidden2_bias_read), f(a,b)(a + b))";
RankProfileSearchFixture search = fixtureWithUncompiled(rankProfiles, new StoringApplicationPackage(applicationDir));
- search.compileRankProfile("my_profile");
- search.compileRankProfile("my_profile_child");
+ search.compileRankProfile("my_profile", applicationDir.append("models"));
+ search.compileRankProfile("my_profile_child", applicationDir.append("models"));
search.assertFirstPhaseExpression(expression, "my_profile");
search.assertFirstPhaseExpression(expression, "my_profile_child");
assertSmallConstant(name + "_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search);
@@ -376,8 +376,8 @@ public class RankingExpressionWithTensorFlowTestCase {
storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory);
RankProfileSearchFixture searchFromStored = fixtureWithUncompiled(rankProfiles, storedApplication);
- searchFromStored.compileRankProfile("my_profile");
- searchFromStored.compileRankProfile("my_profile_child");
+ searchFromStored.compileRankProfile("my_profile", applicationDir.append("models"));
+ searchFromStored.compileRankProfile("my_profile_child", applicationDir.append("models"));
searchFromStored.assertFirstPhaseExpression(expression, "my_profile");
searchFromStored.assertFirstPhaseExpression(expression, "my_profile_child");
assertSmallConstant(name + "_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search);
@@ -453,7 +453,7 @@ public class RankingExpressionWithTensorFlowTestCase {
" }",
constant,
field);
- fixture.compileRankProfile("my_profile");
+ fixture.compileRankProfile("my_profile", applicationDir.append("models"));
return fixture;
}
catch (ParseException e) {
@@ -473,26 +473,19 @@ public class RankingExpressionWithTensorFlowTestCase {
static class StoringApplicationPackage extends MockApplicationPackage {
- private final File root;
-
StoringApplicationPackage(Path applicationPackageWritableRoot) {
this(applicationPackageWritableRoot, null, null);
}
StoringApplicationPackage(Path applicationPackageWritableRoot, String queryProfile, String queryProfileType) {
- super(null, null, Collections.emptyList(), null,
+ super(new File(applicationPackageWritableRoot.toString()),
+ null, null, Collections.emptyList(), null,
null, null, false, queryProfile, queryProfileType);
- this.root = new File(applicationPackageWritableRoot.toString());
- }
-
- @Override
- public File getFileReference(Path path) {
- return Path.fromString(root.toString()).append(path).toFile();
}
@Override
public ApplicationFile getFile(Path file) {
- return new StoringApplicationPackageFile(file, Path.fromString(root.toString()));
+ return new StoringApplicationPackageFile(file, Path.fromString(root().toString()));
}
@Override
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithXgboostTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithXgboostTestCase.java
index f98783ad671..2e109553560 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithXgboostTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithXgboostTestCase.java
@@ -46,7 +46,7 @@ public class RankingExpressionWithXgboostTestCase {
" }",
constant,
field);
- fixture.compileRankProfile("my_profile");
+ fixture.compileRankProfile("my_profile", applicationDir);
return fixture;
} catch (ParseException e) {
throw new IllegalArgumentException(e);
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java
index 18c3e43ae7e..31ceb97ab50 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java
@@ -8,6 +8,7 @@ import com.yahoo.searchdefinition.derived.DerivedConfiguration;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RawRankProfile;
import com.yahoo.searchdefinition.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
import org.junit.Test;
import java.io.IOException;
@@ -39,6 +40,7 @@ public class RankingExpressionsTestCase extends SearchDefinitionTestCase {
List<Pair<String, String>> rankProperties = new RawRankProfile(macrosRankProfile,
new QueryProfileRegistry(),
+ new ImportedModels(),
new AttributeFields(search)).configProperties();
assertEquals(6, rankProperties.size());
@@ -64,7 +66,7 @@ public class RankingExpressionsTestCase extends SearchDefinitionTestCase {
Search search = SearchBuilder.createFromDirectory("src/test/examples/rankingexpressioninfile",
registry,
new QueryProfileRegistry()).getSearch();
- new DerivedConfiguration(search, registry, new QueryProfileRegistry()); // rank profile parsing happens during deriving
+ new DerivedConfiguration(search, registry, new QueryProfileRegistry(), new ImportedModels()); // rank profile parsing happens during deriving
}
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
index 2b2d72dcf34..6b287c77a10 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
@@ -17,6 +17,7 @@ import com.yahoo.searchdefinition.SearchDefinitionTestCase;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RawRankProfile;
import com.yahoo.searchdefinition.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
import org.junit.Test;
import java.util.List;
@@ -199,9 +200,10 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase {
"}\n");
builder.build(true, new BaseDeployLogger());
Search s = builder.getSearch();
- RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(queryProfiles);
+ RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(queryProfiles, new ImportedModels());
List<Pair<String, String>> testRankProperties = new RawRankProfile(test,
queryProfiles,
+ new ImportedModels(),
new AttributeFields(s)).configProperties();
return testRankProperties;
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java
new file mode 100644
index 00000000000..f8bc83735c7
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java
@@ -0,0 +1,70 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.ml;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.yahoo.path.Path;
+
+import java.io.File;
+import java.util.Optional;
+
+/**
+ * All models imported from the models/ directory in the application package
+ *
+ * @author bratseth
+ */
+public class ImportedModels {
+
+ /** All imported models, indexed by their names */
+ private final ImmutableMap<String, ImportedModel> importedModels;
+
+ private static final ImmutableList<ModelImporter> importers = ImmutableList.of(new TensorFlowImporter(), new OnnxImporter());
+
+ /** Create a null imported models */
+ public ImportedModels() {
+ importedModels = ImmutableMap.of();
+ }
+
+ public ImportedModels(File modelsDirectory) {
+ ImmutableMap.Builder<String, ImportedModel> builder = new ImmutableMap.Builder<>();
+
+ // Find all subdirectories recursively which contains a model we can read
+ importRecursively(modelsDirectory, builder);
+ importedModels = builder.build();
+ }
+
+ private static void importRecursively(File dir, ImmutableMap.Builder<String, ImportedModel> builder) {
+ if ( ! dir.isDirectory()) return;
+ for (File child : dir.listFiles()) {
+ Optional<ModelImporter> importer = findImporterOf(child);
+ if (importer.isPresent()) {
+ String name = toName(child);
+ builder.put(name, importer.get().importModel(name, child));
+ }
+ else {
+ importRecursively(child, builder);
+ }
+ }
+ }
+
+ private static Optional<ModelImporter> findImporterOf(File path) {
+ return importers.stream().filter(item -> item.canImport(path.toString())).findFirst();
+ }
+
+ /**
+ * Returns the model at the given location in the application package (lazily loaded),
+ *
+ * @param modelPath the full path to this model (file or directory, depending on model type)
+ * under the application package
+ * @throws IllegalArgumentException if the model cannot be loaded
+ */
+ public ImportedModel get(File modelPath) {
+ return importedModels.get(toName(modelPath));
+ }
+
+ private static String toName(File modelPath) {
+ Path localPath = Path.fromString(modelPath.toString()).getChildPath();
+ return localPath.toString().replace("/", "_").replace('.', '_');
+ }
+
+}