diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-08-22 14:36:03 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-08-22 14:36:03 +0200 |
commit | 4eb133b40206e20e3a70dae7aacec0f6b117e15d (patch) | |
tree | 0bb6fb1c306da10125efc5a5c0d6eec6deae0c22 | |
parent | 7392f9fdbee5f0a52ac9c056376b659b32500c60 (diff) |
Scope imported models to an entire application build
34 files changed, 234 insertions, 153 deletions
diff --git a/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java b/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java index a71a0878d3d..f926259f115 100644 --- a/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java +++ b/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java @@ -231,9 +231,8 @@ public interface ApplicationPackage { */ ApplicationMetaData getMetaData(); - default File getFileReference(Path pathRelativeToAppDir) { - throw new UnsupportedOperationException("This application package cannot return file references"); - } + File getFileReference(Path pathRelativeToAppDir); + default void validateXML() throws IOException { throw new UnsupportedOperationException("This application package cannot validate XML"); } diff --git a/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java b/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java index ff6370f1738..574c25a2f84 100644 --- a/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java +++ b/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java @@ -1,7 +1,6 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.config.model.deploy; -import com.google.common.collect.ImmutableMap; import com.yahoo.component.Version; import com.yahoo.component.Vtag; import com.yahoo.config.application.api.ApplicationPackage; @@ -22,8 +21,8 @@ import com.yahoo.config.provision.Zone; import com.yahoo.io.reader.NamedReader; import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.SearchBuilder; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; import com.yahoo.vespa.config.ConfigDefinition; import com.yahoo.vespa.config.ConfigDefinitionBuilder; import com.yahoo.vespa.config.ConfigDefinitionKey; @@ -67,7 +66,7 @@ public class DeployState implements ConfigDefinitionStore { private final Zone zone; private final QueryProfiles queryProfiles; private final SemanticRules semanticRules; - //private final ImmutableMap<String, ImportedModel> importedMlModels; + private final ImportedModels importedModels; private final ValidationOverrides validationOverrides; private final Version wantedNodeVespaVersion; private final Instant now; @@ -101,6 +100,7 @@ public class DeployState implements ConfigDefinitionStore { this.zone = zone; this.queryProfiles = queryProfiles; // TODO: Remove this by seeing how pagetemplates are propagated this.semanticRules = semanticRules; // TODO: Remove this by seeing how pagetemplates are propagated + this.importedModels = new ImportedModels(applicationPackage.getFileReference(ApplicationPackage.MODELS_DIR)); this.validationOverrides = applicationPackage.getValidationOverrides().map(ValidationOverrides::fromXml).orElse(ValidationOverrides.empty); this.wantedNodeVespaVersion = wantedNodeVespaVersion; @@ -215,7 +215,7 @@ public class DeployState implements ConfigDefinitionStore { public SemanticRules getSemanticRules() { return semanticRules; } /** The (machine learned) models imported from the models/ directory, as an unmodifiable map indexed by model name */ - //public Map<String, ImportedModel> importedMlModels() { return importedMlModels; } + public ImportedModels getImportedModels() { return importedModels; } public Version getWantedNodeVespaVersion() { return wantedNodeVespaVersion; } diff --git a/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java b/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java index 0cfde3c655c..7404ae14a5d 100644 --- a/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java +++ b/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java @@ -32,6 +32,7 @@ public class MockApplicationPackage implements ApplicationPackage { public static final String MUSIC_SEARCHDEFINITION = createSearchDefinition("music", "foo"); public static final String BOOK_SEARCHDEFINITION = createSearchDefinition("book", "bar"); + private final File root; private final String hostsS; private final String servicesS; private final List<String> searchDefinitions; @@ -42,9 +43,11 @@ public class MockApplicationPackage implements ApplicationPackage { private final QueryProfileRegistry queryProfileRegistry; private final ApplicationMetaData applicationMetaData; - protected MockApplicationPackage(String hosts, String services, List<String> searchDefinitions, String searchDefinitionDir, + protected MockApplicationPackage(File root, String hosts, String services, List<String> searchDefinitions, + String searchDefinitionDir, String deploymentSpec, String validationOverrides, boolean failOnValidateXml, String queryProfile, String queryProfileType) { + this.root = root; this.hostsS = hosts; this.servicesS = services; this.searchDefinitions = searchDefinitions; @@ -57,6 +60,9 @@ public class MockApplicationPackage implements ApplicationPackage { applicationMetaData = new ApplicationMetaData("user", "dir", 0L, false, "application", "checksum", 0L, 0L); } + /** Returns the root of this application package relative to the current dir */ + protected File root() { return root; } + @Override public String getApplicationName() { return "mock application"; @@ -111,6 +117,11 @@ public class MockApplicationPackage implements ApplicationPackage { } @Override + public File getFileReference(Path path) { + return Path.fromString(root.toString()).append(path).toFile(); + } + + @Override public String getHostSource() { return "mock source"; } @@ -163,6 +174,7 @@ public class MockApplicationPackage implements ApplicationPackage { public static class Builder { + private File root = new File("nonexisting"); private String hosts = null; private String services = null; private List<String> searchDefinitions = Collections.emptyList(); @@ -176,6 +188,11 @@ public class MockApplicationPackage implements ApplicationPackage { public Builder() { } + public Builder withRoot(File root) { + this.root = root; + return this; + } + public Builder withEmptyHosts() { return this.withHosts(emptyHosts); } @@ -235,7 +252,7 @@ public class MockApplicationPackage implements ApplicationPackage { } public ApplicationPackage build() { - return new MockApplicationPackage(hosts, services, searchDefinitions, searchDefinitionDir, + return new MockApplicationPackage(root, hosts, services, searchDefinitions, searchDefinitionDir, deploymentSpec, validationOverrides, failOnValidateXml, queryProfile, queryProfileType); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index 9fec1983465..2e66784527d 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -16,6 +16,7 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.TypeContext; @@ -682,10 +683,10 @@ public class RankProfile implements Serializable, Cloneable { * Returns a copy of this where the content is optimized for execution. * Compiled profiles should never be modified. */ - public RankProfile compile(QueryProfileRegistry queryProfiles) { + public RankProfile compile(QueryProfileRegistry queryProfiles, ImportedModels importedModels) { try { RankProfile compiled = this.clone(); - compiled.compileThis(queryProfiles); + compiled.compileThis(queryProfiles, importedModels); return compiled; } catch (IllegalArgumentException e) { @@ -693,19 +694,19 @@ public class RankProfile implements Serializable, Cloneable { } } - private void compileThis(QueryProfileRegistry queryProfiles) { + private void compileThis(QueryProfileRegistry queryProfiles, ImportedModels importedModels) { parseExpressions(); checkNameCollisions(getMacros(), getConstants()); ExpressionTransforms expressionTransforms = new ExpressionTransforms(); // Macro compiling first pass: compile inline macros without resolving other macros - Map<String, Macro> inlineMacros = compileMacros(getInlineMacros(), queryProfiles, Collections.emptyMap(), expressionTransforms); + Map<String, Macro> inlineMacros = compileMacros(getInlineMacros(), queryProfiles, importedModels, Collections.emptyMap(), expressionTransforms); // Macro compiling second pass: compile all macros and insert previously compiled inline macros - macros = compileMacros(getMacros(), queryProfiles, inlineMacros, expressionTransforms); + macros = compileMacros(getMacros(), queryProfiles, importedModels, inlineMacros, expressionTransforms); - firstPhaseRanking = compile(this.getFirstPhaseRanking(), queryProfiles, getConstants(), inlineMacros, expressionTransforms); - secondPhaseRanking = compile(this.getSecondPhaseRanking(), queryProfiles, getConstants(), inlineMacros, expressionTransforms); + firstPhaseRanking = compile(this.getFirstPhaseRanking(), queryProfiles, importedModels, getConstants(), inlineMacros, expressionTransforms); + secondPhaseRanking = compile(this.getSecondPhaseRanking(), queryProfiles, importedModels, getConstants(), inlineMacros, expressionTransforms); } private void checkNameCollisions(Map<String, Macro> macros, Map<String, Value> constants) { @@ -723,12 +724,13 @@ public class RankProfile implements Serializable, Cloneable { private Map<String, Macro> compileMacros(Map<String, Macro> macros, QueryProfileRegistry queryProfiles, + ImportedModels importedModels, Map<String, Macro> inlineMacros, ExpressionTransforms expressionTransforms) { Map<String, Macro> compiledMacros = new LinkedHashMap<>(); for (Map.Entry<String, Macro> entry : macros.entrySet()) { Macro macro = entry.getValue().clone(); - RankingExpression exp = compile(macro.getRankingExpression(), queryProfiles, getConstants(), inlineMacros, expressionTransforms); + RankingExpression exp = compile(macro.getRankingExpression(), queryProfiles, importedModels, getConstants(), inlineMacros, expressionTransforms); macro.setRankingExpression(exp); compiledMacros.put(entry.getKey(), macro); } @@ -737,6 +739,7 @@ public class RankProfile implements Serializable, Cloneable { private RankingExpression compile(RankingExpression expression, QueryProfileRegistry queryProfiles, + ImportedModels importedModels, Map<String, Value> constants, Map<String, Macro> inlineMacros, ExpressionTransforms expressionTransforms) { @@ -745,6 +748,7 @@ public class RankProfile implements Serializable, Cloneable { RankProfileTransformContext context = new RankProfileTransformContext(this, queryProfiles, + importedModels, constants, inlineMacros, rankPropertiesOutput); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java index 985087e905b..4af26b72817 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java @@ -12,6 +12,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.derived.validation.Validation; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import java.io.IOException; import java.io.Writer; @@ -46,8 +47,11 @@ public class DerivedConfiguration { * modified. * @param rankProfileRegistry a {@link com.yahoo.searchdefinition.RankProfileRegistry} */ - public DerivedConfiguration(Search search, RankProfileRegistry rankProfileRegistry, QueryProfileRegistry queryProfiles) { - this(search, null, new BaseDeployLogger(), rankProfileRegistry, queryProfiles); + public DerivedConfiguration(Search search, + RankProfileRegistry rankProfileRegistry, + QueryProfileRegistry queryProfiles, + ImportedModels importedModels) { + this(search, null, new BaseDeployLogger(), rankProfileRegistry, queryProfiles, importedModels); } /** @@ -63,10 +67,12 @@ public class DerivedConfiguration { * @param rankProfileRegistry a {@link com.yahoo.searchdefinition.RankProfileRegistry} * @param queryProfiles the query profiles of this application */ - public DerivedConfiguration(Search search, List<Search> abstractSearchList, + public DerivedConfiguration(Search search, + List<Search> abstractSearchList, DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry, - QueryProfileRegistry queryProfiles) { + QueryProfileRegistry queryProfiles, + ImportedModels importedModels) { Validator.ensureNotNull("Search definition", search); if ( ! search.isProcessed()) { throw new IllegalArgumentException("Search '" + search.getName() + "' not processed."); @@ -88,7 +94,7 @@ public class DerivedConfiguration { summaries = new Summaries(search, deployLogger); summaryMap = new SummaryMap(search, summaries); juniperrc = new Juniperrc(search); - rankProfileList = new RankProfileList(search, attributeFields, rankProfileRegistry, queryProfiles); + rankProfileList = new RankProfileList(search, attributeFields, rankProfileRegistry, queryProfiles, importedModels); indexingScript = new IndexingScript(search); indexInfo = new IndexInfo(search); indexSchema = new IndexSchema(search); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java index 77645331d9e..1e978e43d6a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java @@ -3,6 +3,7 @@ package com.yahoo.searchdefinition.derived; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankProfileRegistry; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import com.yahoo.vespa.config.search.RankProfilesConfig; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.Search; @@ -26,24 +27,27 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ public RankProfileList(Search search, AttributeFields attributeFields, RankProfileRegistry rankProfileRegistry, - QueryProfileRegistry queryProfiles) { + QueryProfileRegistry queryProfiles, + ImportedModels importedModels) { setName(search.getName()); - deriveRankProfiles(rankProfileRegistry, queryProfiles, search, attributeFields); + deriveRankProfiles(rankProfileRegistry, queryProfiles, importedModels, search, attributeFields); } private void deriveRankProfiles(RankProfileRegistry rankProfileRegistry, QueryProfileRegistry queryProfiles, + ImportedModels importedModels, Search search, AttributeFields attributeFields) { RawRankProfile defaultProfile = new RawRankProfile(rankProfileRegistry.getRankProfile(search, "default"), queryProfiles, + importedModels, attributeFields); rankProfiles.put(defaultProfile.getName(), defaultProfile); for (RankProfile rank : rankProfileRegistry.localRankProfiles(search)) { if ("default".equals(rank.getName())) continue; - RawRankProfile rawRank = new RawRankProfile(rank, queryProfiles, attributeFields); + RawRankProfile rawRank = new RawRankProfile(rank, queryProfiles, importedModels, attributeFields); rankProfiles.put(rawRank.getName(), rawRank); } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java index b11d5962713..0d104a97698 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java @@ -9,6 +9,7 @@ import com.yahoo.searchdefinition.document.RankType; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.SerializationContext; @@ -47,9 +48,9 @@ public class RawRankProfile implements RankProfilesConfig.Producer { /** * Creates a raw rank profile from the given rank profile */ - public RawRankProfile(RankProfile rankProfile, QueryProfileRegistry queryProfiles, AttributeFields attributeFields) { + public RawRankProfile(RankProfile rankProfile, QueryProfileRegistry queryProfiles, ImportedModels importedModels, AttributeFields attributeFields) { this.name = rankProfile.getName(); - compressedProperties = compress(removePartFromKeys(new Deriver(rankProfile, queryProfiles, attributeFields).derive())); + compressedProperties = compress(removePartFromKeys(new Deriver(rankProfile, queryProfiles, importedModels, attributeFields).derive())); } private List<Pair<String, String>> removePartFromKeys(Map<String, String> map) { @@ -155,8 +156,8 @@ public class RawRankProfile implements RankProfilesConfig.Producer { /** * Creates a raw rank profile from the given rank profile */ - public Deriver(RankProfile rankProfile, QueryProfileRegistry queryProfiles, AttributeFields attributeFields) { - RankProfile compiled = rankProfile.compile(queryProfiles); + public Deriver(RankProfile rankProfile, QueryProfileRegistry queryProfiles, ImportedModels importedModels, AttributeFields attributeFields) { + RankProfile compiled = rankProfile.compile(queryProfiles, importedModels); attributeTypes = compiled.getAttributeTypes(); queryFeatureTypes = compiled.getQueryFeatureTypes(); deriveRankingFeatures(compiled); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java index 21e51b68677..a38fbe1aaa0 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java @@ -16,6 +16,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; @@ -77,14 +78,12 @@ public class ConvertedModel { * Create a converted model for a rank profile given from either an imported model, * or (if unavailable) from stored application package data. */ - public ConvertedModel(Path modelPath, - RankProfileTransformContext context, - ImportedModels importedModels) { + public ConvertedModel(Path modelPath, RankProfileTransformContext context) { this.modelPath = modelPath; this.modelName = toModelName(modelPath); ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), modelPath); if ( store.hasSourceModel()) - expressions = convertModel(store, context.rankProfile(), context.queryProfiles(), importedModels); + expressions = convertModel(store, context.rankProfile(), context.queryProfiles(), context.importedModels()); else expressions = transformFromStoredModel(store, context.rankProfile()); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ImportedModels.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ImportedModels.java deleted file mode 100644 index 7cf0a5d8b76..00000000000 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ImportedModels.java +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchdefinition.expressiontransforms; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import com.yahoo.path.Path; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; -import com.yahoo.searchlib.rankingexpression.integration.ml.ModelImporter; -import com.yahoo.searchlib.rankingexpression.integration.ml.OnnxImporter; -import com.yahoo.searchlib.rankingexpression.integration.ml.TensorFlowImporter; - -import java.io.File; -import java.util.HashMap; -import java.util.Map; - -/** - * All models imported from the models/ directory in the application package - * - * @author bratseth - */ -class ImportedModels { - - /** The cache of already imported models */ - private final Map<String, ImportedModel> importedModels = new HashMap<>(); - - private final ImmutableList<ModelImporter> importers = ImmutableList.of(new TensorFlowImporter(), new OnnxImporter()); - - ImportedModels() { - } - - /** - * Returns the model at the given location in the application package (lazily loaded), - * - * @param modelPath the full path to this model (file or directory, depending on model type) - * under the application package - * @throws IllegalArgumentException if the model cannot be loaded - */ - public ImportedModel get(File modelPath) { - String modelName = toName(modelPath); - ModelImporter importer = importers.stream().filter(item -> item.canImport(modelPath.toString())).findFirst().get(); - return importedModels.computeIfAbsent(modelName, __ -> importer.importModel(modelName, modelPath)); - } - - private static String toName(File modelPath) { - Path localPath = Path.fromString(modelPath.toString()).getChildPath(); - return localPath.toString().replace("/", "_").replace('.', '_'); - } - -} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java index 0b68a67acff..36dc200f3c9 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java @@ -3,13 +3,13 @@ package com.yahoo.searchdefinition.expressiontransforms; import com.yahoo.path.Path; -import com.yahoo.searchlib.rankingexpression.integration.ml.OnnxImporter; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import java.io.File; import java.io.UncheckedIOException; import java.util.HashMap; import java.util.Map; @@ -24,8 +24,6 @@ import java.util.Map; */ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> { - private final ImportedModels importedOnnxModels = new ImportedModels(); - /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */ private final Map<Path, ConvertedModel> convertedOnnxModels = new HashMap<>(); @@ -45,7 +43,7 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans try { Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0))); ConvertedModel convertedModel = - convertedOnnxModels.computeIfAbsent(modelPath, __ -> new ConvertedModel(modelPath, context, importedOnnxModels)); + convertedOnnxModels.computeIfAbsent(modelPath, __ -> new ConvertedModel(modelPath, context)); return convertedModel.expression(asFeatureArguments(feature.getArguments())); } catch (IllegalArgumentException | UncheckedIOException e) { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java index 5da5b3dabda..c7b4e85d74e 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java @@ -4,6 +4,7 @@ package com.yahoo.searchdefinition.expressiontransforms; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import com.yahoo.searchlib.rankingexpression.transform.TransformContext; import java.util.Map; @@ -17,23 +18,27 @@ public class RankProfileTransformContext extends TransformContext { private final RankProfile rankProfile; private final QueryProfileRegistry queryProfiles; + private final ImportedModels importedModels; private final Map<String, RankProfile.Macro> inlineMacros; private final Map<String, String> rankPropertiesOutput; public RankProfileTransformContext(RankProfile rankProfile, QueryProfileRegistry queryProfiles, + ImportedModels importedModels, Map<String, Value> constants, Map<String, RankProfile.Macro> inlineMacros, Map<String, String> rankPropertiesOutput) { super(constants); this.rankProfile = rankProfile; this.queryProfiles = queryProfiles; + this.importedModels = importedModels; this.inlineMacros = inlineMacros; this.rankPropertiesOutput = rankPropertiesOutput; } public RankProfile rankProfile() { return rankProfile; } public QueryProfileRegistry queryProfiles() { return queryProfiles; } + public ImportedModels importedModels() { return importedModels; } public Map<String, RankProfile.Macro> inlineMacros() { return inlineMacros; } public Map<String, String> rankPropertiesOutput() { return rankPropertiesOutput; } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index 4f15fb5a291..619c13da764 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -2,13 +2,13 @@ package com.yahoo.searchdefinition.expressiontransforms; import com.yahoo.path.Path; -import com.yahoo.searchlib.rankingexpression.integration.ml.TensorFlowImporter; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import java.io.File; import java.io.UncheckedIOException; import java.util.HashMap; import java.util.Map; @@ -22,8 +22,6 @@ import java.util.Map; */ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> { - private final ImportedModels importedTensorFlowModels = new ImportedModels(); - /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */ private final Map<Path, ConvertedModel> convertedTensorFlowModels = new HashMap<>(); @@ -43,7 +41,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil try { Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0))); ConvertedModel convertedModel = - convertedTensorFlowModels.computeIfAbsent(modelPath, __ -> new ConvertedModel(modelPath, context, importedTensorFlowModels)); + convertedTensorFlowModels.computeIfAbsent(modelPath, __ -> new ConvertedModel(modelPath, context)); return convertedModel.expression(asFeatureArguments(feature.getArguments())); } catch (IllegalArgumentException | UncheckedIOException e) { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java index 6e0aa562508..c7762f09851 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java @@ -289,9 +289,12 @@ public class IndexedSearchCluster extends SearchCluster if ( ! (search instanceof UnproperSearch)) { DocumentDatabase db = new DocumentDatabase(this, search.getName(), - new DerivedConfiguration(search, globalSearches, deployLogger(), + new DerivedConfiguration(search, + globalSearches, + deployLogger(), getRoot().getDeployState().rankProfileRegistry(), - getRoot().getDeployState().getQueryProfiles().getRegistry())); + getRoot().getDeployState().getQueryProfiles().getRegistry(), + getRoot().getDeployState().getImportedModels())); // TODO: remove explicit adding of user configs when the complete content model is built using builders. db.mergeUserConfigs(spec.getUserConfigs()); documentDbs.add(db); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/StreamingSearchCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/search/StreamingSearchCluster.java index e87df3d530b..66553ebadf6 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/search/StreamingSearchCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/search/StreamingSearchCluster.java @@ -104,7 +104,8 @@ public class StreamingSearchCluster extends SearchCluster implements } this.sdConfig = new DerivedConfiguration(localSearch, globalSearches, deployLogger(), getRoot().getDeployState().rankProfileRegistry(), - getRoot().getDeployState().getQueryProfiles().getRegistry()); + getRoot().getDeployState().getQueryProfiles().getRegistry(), + getRoot().getDeployState().getImportedModels()); } @Override public DerivedConfiguration getSdConfig() { diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java index bff34411d44..03fa92f5cb9 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java @@ -4,6 +4,7 @@ package com.yahoo.searchdefinition; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.derived.DerivedConfiguration; import com.yahoo.searchdefinition.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import org.junit.Test; import java.io.IOException; @@ -23,7 +24,7 @@ public class IncorrectRankingExpressionFileRefTestCase extends SearchDefinitionT Search search = SearchBuilder.buildFromFile("src/test/examples/incorrectrankingexpressionfileref.sd", registry, new QueryProfileRegistry()); - new DerivedConfiguration(search, registry, new QueryProfileRegistry()); // cause rank profile parsing + new DerivedConfiguration(search, registry, new QueryProfileRegistry(), new ImportedModels()); // cause rank profile parsing fail("parsing should have failed"); } catch (IllegalArgumentException e) { e.printStackTrace(); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java index ed612039fb7..de9df08f5c0 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java @@ -16,6 +16,7 @@ import com.yahoo.searchdefinition.document.RankType; import com.yahoo.searchdefinition.document.SDDocumentType; import com.yahoo.searchdefinition.document.SDField; import com.yahoo.searchdefinition.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import com.yahoo.vespa.model.container.search.QueryProfiles; import org.junit.Test; @@ -91,7 +92,7 @@ public class RankProfileTestCase extends SearchDefinitionTestCase { assertEquals(8, rankProfile.getNumThreadsPerSearch()); assertEquals(70, rankProfile.getMinHitsPerThread()); assertEquals(1200, rankProfile.getNumSearchPartitions()); - RawRankProfile rawRankProfile = new RawRankProfile(rankProfile, new QueryProfileRegistry(), attributeFields); + RawRankProfile rawRankProfile = new RawRankProfile(rankProfile, new QueryProfileRegistry(), new ImportedModels(), attributeFields); assertTrue(findProperty(rawRankProfile.configProperties(), "vespa.matching.termwise_limit").isPresent()); assertEquals("0.78", findProperty(rawRankProfile.configProperties(), "vespa.matching.termwise_limit").get()); assertTrue(findProperty(rawRankProfile.configProperties(), "vespa.matching.numthreadspersearch").isPresent()); @@ -126,7 +127,7 @@ public class RankProfileTestCase extends SearchDefinitionTestCase { } private static void assertAttributeTypeSettings(RankProfile profile, Search search) { - RawRankProfile rawProfile = new RawRankProfile(profile, new QueryProfileRegistry(), new AttributeFields(search)); + RawRankProfile rawProfile = new RawRankProfile(profile, new QueryProfileRegistry(), new ImportedModels(), new AttributeFields(search)); assertEquals("tensor(x[10])", findProperty(rawProfile.configProperties(), "vespa.type.attribute.a").get()); assertEquals("tensor(y{})", findProperty(rawProfile.configProperties(), "vespa.type.attribute.b").get()); assertEquals("tensor(x[])", findProperty(rawProfile.configProperties(), "vespa.type.attribute.c").get()); @@ -168,7 +169,7 @@ public class RankProfileTestCase extends SearchDefinitionTestCase { } private static void assertQueryFeatureTypeSettings(RankProfile profile, Search search) { - RawRankProfile rawProfile = new RawRankProfile(profile, new QueryProfileRegistry(), new AttributeFields(search)); + RawRankProfile rawProfile = new RawRankProfile(profile, new QueryProfileRegistry(), new ImportedModels(), new AttributeFields(search)); assertEquals("tensor(x[10])", findProperty(rawProfile.configProperties(), "vespa.type.query.tensor1").get()); assertEquals("tensor(y{})", findProperty(rawProfile.configProperties(), "vespa.type.query.tensor2").get()); assertFalse(findProperty(rawProfile.configProperties(), "vespa.type.query.tensor3").isPresent()); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java index 15ddef60807..3a2482b56d0 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java @@ -7,6 +7,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import org.junit.Test; import java.util.ArrayList; @@ -60,7 +61,7 @@ public class RankPropertiesTestCase extends SearchDefinitionTestCase { assertEquals("query(a) = 1500", parent.getRankProperties().get(0).toString()); // Check derived model - RawRankProfile rawParent = new RawRankProfile(parent, new QueryProfileRegistry(), attributeFields); + RawRankProfile rawParent = new RawRankProfile(parent, new QueryProfileRegistry(), new ImportedModels(), attributeFields); assertEquals("(query(a),1500)", rawParent.configProperties().get(0).toString()); } @@ -72,6 +73,7 @@ public class RankPropertiesTestCase extends SearchDefinitionTestCase { // Check derived model RawRankProfile rawChild = new RawRankProfile(rankProfileRegistry.getRankProfile(search, "child"), new QueryProfileRegistry(), + new ImportedModels(), attributeFields); assertEquals("(query(a),2000)", rawChild.configProperties().get(0).toString()); } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java index 82b9f5ac043..da546967dc1 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java @@ -4,6 +4,7 @@ package com.yahoo.searchdefinition; import com.yahoo.collections.Pair; import com.yahoo.config.model.application.provider.BaseDeployLogger; import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import com.yahoo.yolean.Exceptions; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; @@ -69,18 +70,19 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase "}\n"); builder.build(); Search s = builder.getSearch(); - RankProfile parent = rankProfileRegistry.getRankProfile(s, "parent").compile(queryProfileRegistry); + RankProfile parent = rankProfileRegistry.getRankProfile(s, "parent").compile(queryProfileRegistry, new ImportedModels()); assertEquals("0.0", parent.getFirstPhaseRanking().getRoot().toString()); - RankProfile child1 = rankProfileRegistry.getRankProfile(s, "child1").compile(queryProfileRegistry); + RankProfile child1 = rankProfileRegistry.getRankProfile(s, "child1").compile(queryProfileRegistry, new ImportedModels()); assertEquals("6.5", child1.getFirstPhaseRanking().getRoot().toString()); assertEquals("11.5", child1.getSecondPhaseRanking().getRoot().toString()); - RankProfile child2 = rankProfileRegistry.getRankProfile(s, "child2").compile(queryProfileRegistry); + RankProfile child2 = rankProfileRegistry.getRankProfile(s, "child2").compile(queryProfileRegistry, new ImportedModels()); assertEquals("16.6", child2.getFirstPhaseRanking().getRoot().toString()); assertEquals("foo: 14.0", child2.getMacros().get("foo").getRankingExpression().toString()); List<Pair<String, String>> rankProperties = new RawRankProfile(child2, queryProfileRegistry, + new ImportedModels(), new AttributeFields(s)).configProperties(); assertEquals("(rankingExpression(foo).rankingScript,14.0)", rankProperties.get(0).toString()); assertEquals("(rankingExpression(firstphase).rankingScript,16.6)", rankProperties.get(2).toString()); @@ -111,7 +113,7 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase builder.build(); Search s = builder.getSearch(); try { - rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry()); + rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry(), new ImportedModels()); fail("Should have caused an exception"); } catch (IllegalArgumentException e) { @@ -174,7 +176,7 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase profile.parseExpressions(); // TODO: Do differently assertEquals("safeLog(popShareSlowDecaySignal,myValue)", profile.getMacros().get("POP_SLOW_SCORE").getRankingExpression().getRoot().toString()); assertEquals("safeLog(popShareSlowDecaySignal,-9.21034037)", - profile.compile(new QueryProfileRegistry()).getMacros().get("POP_SLOW_SCORE").getRankingExpression().getRoot().toString()); + profile.compile(new QueryProfileRegistry(), new ImportedModels()).getMacros().get("POP_SLOW_SCORE").getRankingExpression().getRoot().toString()); } @Test @@ -197,7 +199,7 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase Search s = builder.getSearch(); RankProfile profile = rankProfileRegistry.getRankProfile(s, "test"); assertEquals("k1 + (k2 + k3) / 100000000.0", - profile.compile(new QueryProfileRegistry()).getMacros().get("rank_default").getRankingExpression().getRoot().toString()); + profile.compile(new QueryProfileRegistry(), new ImportedModels()).getMacros().get("rank_default").getRankingExpression().getRoot().toString()); } @Test @@ -223,7 +225,7 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase Search s = builder.getSearch(); RankProfile profile = rankProfileRegistry.getRankProfile(s, "test"); assertEquals("0.5 + 50 * (attribute(rating_yelp) - 3)", - profile.compile(new QueryProfileRegistry()).getMacros().get("rank_default").getRankingExpression().getRoot().toString()); + profile.compile(new QueryProfileRegistry(), new ImportedModels()).getMacros().get("rank_default").getRankingExpression().getRoot().toString()); } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java index a0dd18aeea9..555aa698c65 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java @@ -6,6 +6,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import org.junit.Test; import java.util.Optional; @@ -61,10 +62,10 @@ public class RankingExpressionInliningTestCase extends SearchDefinitionTestCase builder.build(); Search s = builder.getSearch(); - RankProfile parent = rankProfileRegistry.getRankProfile(s, "parent").compile(new QueryProfileRegistry()); + RankProfile parent = rankProfileRegistry.getRankProfile(s, "parent").compile(new QueryProfileRegistry(), new ImportedModels()); assertEquals("7.0 * (3 + attribute(a) + attribute(b) * (attribute(a) * 3 + if (7.0 < attribute(a), 1, 2) == 0))", parent.getFirstPhaseRanking().getRoot().toString()); - RankProfile child = rankProfileRegistry.getRankProfile(s, "child").compile(new QueryProfileRegistry()); + RankProfile child = rankProfileRegistry.getRankProfile(s, "child").compile(new QueryProfileRegistry(), new ImportedModels()); assertEquals("7.0 * (9 + attribute(a))", child.getFirstPhaseRanking().getRoot().toString()); } @@ -121,14 +122,14 @@ public class RankingExpressionInliningTestCase extends SearchDefinitionTestCase builder.build(); Search s = builder.getSearch(); - RankProfile parent = rankProfileRegistry.getRankProfile(s, "parent").compile(new QueryProfileRegistry()); + RankProfile parent = rankProfileRegistry.getRankProfile(s, "parent").compile(new QueryProfileRegistry(), new ImportedModels()); assertEquals("17.0", parent.getFirstPhaseRanking().getRoot().toString()); assertEquals("0.0", parent.getSecondPhaseRanking().getRoot().toString()); assertEquals("10.0", getRankingExpression("foo", parent, s)); assertEquals("17.0", getRankingExpression("firstphase", parent, s)); assertEquals("0.0", getRankingExpression("secondphase", parent, s)); - RankProfile child = rankProfileRegistry.getRankProfile(s, "child").compile(new QueryProfileRegistry()); + RankProfile child = rankProfileRegistry.getRankProfile(s, "child").compile(new QueryProfileRegistry(), new ImportedModels()); assertEquals("31.0 + bar + arg(4.0)", child.getFirstPhaseRanking().getRoot().toString()); assertEquals("24.0", child.getSecondPhaseRanking().getRoot().toString()); assertEquals("12.0", getRankingExpression("foo", child, s)); @@ -177,7 +178,7 @@ public class RankingExpressionInliningTestCase extends SearchDefinitionTestCase builder.build(); Search s = builder.getSearch(); - RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry()); + RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry(), new ImportedModels()); assertEquals("attribute(a) + C + (attribute(b) + 1)", test.getFirstPhaseRanking().getRoot().toString()); assertEquals("attribute(a) + attribute(b)", getRankingExpression("C", test, s)); assertEquals("attribute(b) + 1", getRankingExpression("D", test, s)); @@ -208,7 +209,7 @@ public class RankingExpressionInliningTestCase extends SearchDefinitionTestCase private String getRankingExpression(String name, RankProfile rankProfile, Search search) { Optional<String> rankExpression = - new RawRankProfile(rankProfile, new QueryProfileRegistry(), new AttributeFields(search)) + new RawRankProfile(rankProfile, new QueryProfileRegistry(), new ImportedModels(), new AttributeFields(search)) .configProperties() .stream() .filter(r -> r.getFirst().equals("rankingExpression(" + name + ").rankingScript")) diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java index ed1b00e2875..2bd7d3031a5 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java @@ -9,6 +9,7 @@ import com.yahoo.search.query.profile.types.QueryProfileType; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import org.junit.Ignore; import org.junit.Test; @@ -44,9 +45,10 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase "}\n"); builder.build(); Search s = builder.getSearch(); - RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry()); + RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry(), new ImportedModels()); List<Pair<String, String>> testRankProperties = new RawRankProfile(test, new QueryProfileRegistry(), + new ImportedModels(), new AttributeFields(s)).configProperties(); assertEquals("(rankingExpression(sin).rankingScript,x * x)", testRankProperties.get(0).toString()); @@ -87,9 +89,10 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase "}\n"); builder.build(); Search s = builder.getSearch(); - RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry()); + RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry(), new ImportedModels()); List<Pair<String, String>> testRankProperties = new RawRankProfile(test, new QueryProfileRegistry(), + new ImportedModels(), new AttributeFields(s)).configProperties(); assertEquals("(rankingExpression(tan).rankingScript,x * x)", testRankProperties.get(0).toString()); @@ -136,9 +139,10 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase "}\n"); builder.build(); Search s = builder.getSearch(); - RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry()); + RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry(), new ImportedModels()); List<Pair<String, String>> testRankProperties = new RawRankProfile(test, new QueryProfileRegistry(), + new ImportedModels(), new AttributeFields(s)).configProperties(); assertEquals("(rankingExpression(sin).rankingScript,x * x)", testRankProperties.get(0).toString()); @@ -199,9 +203,10 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase "}\n"); builder.build(); Search s = builder.getSearch(); - RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(queryProfiles); + RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(queryProfiles, new ImportedModels()); List<Pair<String, String>> testRankProperties = new RawRankProfile(test, queryProfiles, + new ImportedModels(), new AttributeFields(s)).configProperties(); assertEquals("(rankingExpression(relu).rankingScript,max(1.0,x))", testRankProperties.get(0).toString()); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java index a07fea69592..3fe3a7c3de1 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java @@ -4,6 +4,7 @@ package com.yahoo.searchdefinition; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.derived.DerivedConfiguration; import com.yahoo.searchdefinition.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import org.junit.Ignore; import org.junit.Test; @@ -24,7 +25,7 @@ public class RankingExpressionValidationTestCase extends SearchDefinitionTestCas try { RankProfileRegistry registry = new RankProfileRegistry(); Search search = importWithExpression(expression, registry); - new DerivedConfiguration(search, registry, new QueryProfileRegistry()); // cause rank profile parsing + new DerivedConfiguration(search, registry, new QueryProfileRegistry(), new ImportedModels()); // cause rank profile parsing fail("No exception on incorrect ranking expression " + expression); } catch (IllegalArgumentException e) { // Success diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java index 3d2bce62713..88a02cc7a93 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java @@ -8,6 +8,7 @@ import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.SearchDefinitionTestCase; import com.yahoo.searchdefinition.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import com.yahoo.vespa.configmodel.producers.DocumentManager; import com.yahoo.vespa.configmodel.producers.DocumentTypes; @@ -80,14 +81,16 @@ public abstract class AbstractExportingTestCase extends SearchDefinitionTestCase protected DerivedConfiguration derive(String dirName, String searchDefinitionName, SearchBuilder builder) throws IOException { DerivedConfiguration config = new DerivedConfiguration(builder.getSearch(searchDefinitionName), builder.getRankProfileRegistry(), - builder.getQueryProfileRegistry()); + builder.getQueryProfileRegistry(), + new ImportedModels()); return export(dirName, builder, config); } protected DerivedConfiguration derive(String dirName, SearchBuilder builder, Search search) throws IOException { DerivedConfiguration config = new DerivedConfiguration(search, builder.getRankProfileRegistry(), - builder.getQueryProfileRegistry()); + builder.getQueryProfileRegistry(), + new ImportedModels()); return export(dirName, builder, config); } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java index 21467776ad9..f4344c9b03c 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java @@ -10,6 +10,7 @@ import com.yahoo.searchdefinition.SearchDefinitionTestCase; import com.yahoo.searchdefinition.document.SDDocumentType; import com.yahoo.searchdefinition.document.SDField; import com.yahoo.searchdefinition.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import org.junit.Test; import java.io.IOException; @@ -34,7 +35,7 @@ public class EmptyRankProfileTestCase extends SearchDefinitionTestCase { doc.addField(new SDField("c", DataType.STRING)); search = SearchBuilder.buildFromRawSearch(search, rankProfileRegistry, new QueryProfileRegistry()); - new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry()); + new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(), new ImportedModels()); } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java index 59d2d9879d1..dec4b734f27 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java @@ -11,6 +11,7 @@ import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.document.SDDocumentType; import com.yahoo.searchdefinition.document.SDField; import com.yahoo.searchdefinition.processing.Processing; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import com.yahoo.vespa.model.container.search.QueryProfiles; import org.junit.Test; @@ -41,7 +42,7 @@ public class LiteralBoostTestCase extends AbstractExportingTestCase { other.addRankSetting(new RankProfile.RankSetting("a", RankProfile.RankSetting.Type.LITERALBOOST, 333)); Processing.process(search, new BaseDeployLogger(), rankProfileRegistry, new QueryProfiles(), true); - DerivedConfiguration derived=new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry()); + DerivedConfiguration derived=new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(), new ImportedModels()); // Check attribute fields derived.getAttributeFields(); // TODO: assert content @@ -72,7 +73,7 @@ public class LiteralBoostTestCase extends AbstractExportingTestCase { other.addRankSetting(new RankProfile.RankSetting("a", RankProfile.RankSetting.Type.LITERALBOOST, 333)); search = SearchBuilder.buildFromRawSearch(search, rankProfileRegistry, new QueryProfileRegistry()); - DerivedConfiguration derived = new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry()); + DerivedConfiguration derived = new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(),new ImportedModels()); // Check il script addition assertIndexing(Arrays.asList("clear_state | guard { input a | tokenize normalize stem:\"SHORTEST\" | index a; }", @@ -99,7 +100,7 @@ public class LiteralBoostTestCase extends AbstractExportingTestCase { field2.setLiteralBoost(20); search = SearchBuilder.buildFromRawSearch(search, rankProfileRegistry, new QueryProfileRegistry()); - new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry()); + new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(), new ImportedModels()); assertIndexing(Arrays.asList("clear_state | guard { input title | tokenize normalize stem:\"SHORTEST\" | summary title | index title; }", "clear_state | guard { input body | tokenize normalize stem:\"SHORTEST\" | summary body | index body; }", "clear_state | guard { input title | tokenize | index title_literal; }", diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java index f4edc1dd0ae..723cd58a34a 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java @@ -5,6 +5,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import org.junit.Test; import java.io.File; @@ -34,7 +35,8 @@ public class SimpleInheritTestCase extends AbstractExportingTestCase { DerivedConfiguration config = new DerivedConfiguration(search, builder.getRankProfileRegistry(), - new QueryProfileRegistry()); + new QueryProfileRegistry(), + new ImportedModels()); config.export(toDirName); checkDir(toDirName, expectedResultsDirName); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java index fd029c1df05..26b100a2d96 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java @@ -10,6 +10,7 @@ import com.yahoo.searchdefinition.SearchDefinitionTestCase; import com.yahoo.searchdefinition.document.SDDocumentType; import com.yahoo.searchdefinition.document.SDField; import com.yahoo.searchdefinition.processing.Processing; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import com.yahoo.vespa.model.container.search.QueryProfiles; import org.junit.Test; @@ -33,7 +34,7 @@ public class TypeConversionTestCase extends SearchDefinitionTestCase { document.addField(a); Processing.process(search, new BaseDeployLogger(), rankProfileRegistry, new QueryProfiles(), true); - DerivedConfiguration derived = new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry()); + DerivedConfiguration derived = new DerivedConfiguration(search, rankProfileRegistry, new QueryProfileRegistry(), new ImportedModels()); IndexInfo indexInfo = derived.getIndexInfo(); assertFalse(indexInfo.hasCommand("default", "compact-to-term")); } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java index d743f60201e..d38bce04617 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java @@ -9,6 +9,7 @@ import com.yahoo.searchdefinition.derived.DerivedConfiguration; import com.yahoo.searchdefinition.derived.Deriver; import com.yahoo.searchdefinition.document.SDDocumentType; import com.yahoo.searchdefinition.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import org.junit.Ignore; import org.junit.Test; import java.io.File; @@ -99,7 +100,7 @@ public class ImplicitSearchFieldsTestCase extends SearchDefinitionTestCase { sb.importFile("src/test/examples/nextgen/simple.sd"); sb.build(); assertNotNull(sb.getSearch()); - new DerivedConfiguration(sb.getSearch(), sb.getRankProfileRegistry(), new QueryProfileRegistry()); + new DerivedConfiguration(sb.getSearch(), sb.getRankProfileRegistry(), new QueryProfileRegistry(), new ImportedModels()); } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java index ab689b88993..45cdbfa9c1f 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java @@ -3,12 +3,14 @@ package com.yahoo.searchdefinition.processing; import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.config.model.test.MockApplicationPackage; +import com.yahoo.path.Path; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import java.util.HashMap; import java.util.List; @@ -76,7 +78,11 @@ class RankProfileSearchFixture { } public RankProfile compileRankProfile(String rankProfile) { - RankProfile compiled = rankProfileRegistry.getRankProfile(search, rankProfile).compile(queryProfileRegistry); + return compileRankProfile(rankProfile, Path.fromString("nonexistinng")); + } + + public RankProfile compileRankProfile(String rankProfile, Path applicationDir) { + RankProfile compiled = rankProfileRegistry.getRankProfile(search, rankProfile).compile(queryProfileRegistry, new ImportedModels(applicationDir.toFile())); compiledRankProfiles.put(rankProfile, compiled); return compiled; } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java index a311a2ed706..90137ddde49 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java @@ -127,7 +127,7 @@ public class RankingExpressionWithOnnxTestCase { " expression: onnx('mnist_softmax.onnx')" + " }\n" + " }"); - search.compileRankProfile("my_profile"); + search.compileRankProfile("my_profile", applicationDir.append("models")); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); fail("Expecting exception"); } @@ -226,7 +226,7 @@ public class RankingExpressionWithOnnxTestCase { String vespaExpressionWithoutConstant = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), " + name + "_Variable, f(a,b)(a * b)), sum, d2), constant(" + name + "_Variable_1), f(a,b)(a + b))"; RankProfileSearchFixture search = uncompiledFixtureWith(rankProfile, new StoringApplicationPackage(applicationDir)); - search.compileRankProfile("my_profile"); + search.compileRankProfile("my_profile", applicationDir.append("models")); search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); assertNull("Constant overridden by macro is not added", @@ -241,7 +241,7 @@ public class RankingExpressionWithOnnxTestCase { storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); RankProfileSearchFixture searchFromStored = uncompiledFixtureWith(rankProfile, storedApplication); - searchFromStored.compileRankProfile("my_profile"); + searchFromStored.compileRankProfile("my_profile", applicationDir.append("models")); searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); assertNull("Constant overridden by macro is not added", searchFromStored.search().getRankingConstants().get( name + "_Variable")); @@ -317,7 +317,7 @@ public class RankingExpressionWithOnnxTestCase { " }", constant, field); - fixture.compileRankProfile("my_profile"); + fixture.compileRankProfile("my_profile", applicationDir.append("models")); return fixture; } catch (ParseException e) { diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index a3a286350b5..2804b92767a 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -163,7 +163,7 @@ public class RankingExpressionWithTensorFlowTestCase { " expression: tensorflow('mnist_softmax/saved')" + " }\n" + " }"); - search.compileRankProfile("my_profile"); + search.compileRankProfile("my_profile", applicationDir.append("models")); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); fail("Expecting exception"); } @@ -280,8 +280,8 @@ public class RankingExpressionWithTensorFlowTestCase { String vespaExpressionWithoutConstant = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), " + name + "_layer_Variable_read, f(a,b)(a * b)), sum, d2), constant(" + name + "_layer_Variable_1_read), f(a,b)(a + b))"; RankProfileSearchFixture search = fixtureWithUncompiled(rankProfiles, new StoringApplicationPackage(applicationDir)); - search.compileRankProfile("my_profile"); - search.compileRankProfile("my_profile_child"); + search.compileRankProfile("my_profile", applicationDir.append("models")); + search.compileRankProfile("my_profile_child", applicationDir.append("models")); search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile_child"); @@ -297,8 +297,8 @@ public class RankingExpressionWithTensorFlowTestCase { storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); RankProfileSearchFixture searchFromStored = fixtureWithUncompiled(rankProfiles, storedApplication); - searchFromStored.compileRankProfile("my_profile"); - searchFromStored.compileRankProfile("my_profile_child"); + searchFromStored.compileRankProfile("my_profile", applicationDir.append("models")); + searchFromStored.compileRankProfile("my_profile_child", applicationDir.append("models")); searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile_child"); assertNull("Constant overridden by macro is not added", @@ -358,8 +358,8 @@ public class RankingExpressionWithTensorFlowTestCase { final String macroExpression2 = "join(reduce(join(join(join(imported_ml_macro_" + name + "_dnn_hidden1_add, 0.009999999776482582, f(a,b)(a * b)), imported_ml_macro_" + name + "_dnn_hidden1_add, f(a,b)(max(a,b))), constant(" + name + "_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(" + name + "_dnn_hidden2_bias_read), f(a,b)(a + b))"; RankProfileSearchFixture search = fixtureWithUncompiled(rankProfiles, new StoringApplicationPackage(applicationDir)); - search.compileRankProfile("my_profile"); - search.compileRankProfile("my_profile_child"); + search.compileRankProfile("my_profile", applicationDir.append("models")); + search.compileRankProfile("my_profile_child", applicationDir.append("models")); search.assertFirstPhaseExpression(expression, "my_profile"); search.assertFirstPhaseExpression(expression, "my_profile_child"); assertSmallConstant(name + "_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search); @@ -376,8 +376,8 @@ public class RankingExpressionWithTensorFlowTestCase { storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); RankProfileSearchFixture searchFromStored = fixtureWithUncompiled(rankProfiles, storedApplication); - searchFromStored.compileRankProfile("my_profile"); - searchFromStored.compileRankProfile("my_profile_child"); + searchFromStored.compileRankProfile("my_profile", applicationDir.append("models")); + searchFromStored.compileRankProfile("my_profile_child", applicationDir.append("models")); searchFromStored.assertFirstPhaseExpression(expression, "my_profile"); searchFromStored.assertFirstPhaseExpression(expression, "my_profile_child"); assertSmallConstant(name + "_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search); @@ -453,7 +453,7 @@ public class RankingExpressionWithTensorFlowTestCase { " }", constant, field); - fixture.compileRankProfile("my_profile"); + fixture.compileRankProfile("my_profile", applicationDir.append("models")); return fixture; } catch (ParseException e) { @@ -473,26 +473,19 @@ public class RankingExpressionWithTensorFlowTestCase { static class StoringApplicationPackage extends MockApplicationPackage { - private final File root; - StoringApplicationPackage(Path applicationPackageWritableRoot) { this(applicationPackageWritableRoot, null, null); } StoringApplicationPackage(Path applicationPackageWritableRoot, String queryProfile, String queryProfileType) { - super(null, null, Collections.emptyList(), null, + super(new File(applicationPackageWritableRoot.toString()), + null, null, Collections.emptyList(), null, null, null, false, queryProfile, queryProfileType); - this.root = new File(applicationPackageWritableRoot.toString()); - } - - @Override - public File getFileReference(Path path) { - return Path.fromString(root.toString()).append(path).toFile(); } @Override public ApplicationFile getFile(Path file) { - return new StoringApplicationPackageFile(file, Path.fromString(root.toString())); + return new StoringApplicationPackageFile(file, Path.fromString(root().toString())); } @Override diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithXgboostTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithXgboostTestCase.java index f98783ad671..2e109553560 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithXgboostTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithXgboostTestCase.java @@ -46,7 +46,7 @@ public class RankingExpressionWithXgboostTestCase { " }", constant, field); - fixture.compileRankProfile("my_profile"); + fixture.compileRankProfile("my_profile", applicationDir); return fixture; } catch (ParseException e) { throw new IllegalArgumentException(e); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java index 18c3e43ae7e..31ceb97ab50 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java @@ -8,6 +8,7 @@ import com.yahoo.searchdefinition.derived.DerivedConfiguration; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import org.junit.Test; import java.io.IOException; @@ -39,6 +40,7 @@ public class RankingExpressionsTestCase extends SearchDefinitionTestCase { List<Pair<String, String>> rankProperties = new RawRankProfile(macrosRankProfile, new QueryProfileRegistry(), + new ImportedModels(), new AttributeFields(search)).configProperties(); assertEquals(6, rankProperties.size()); @@ -64,7 +66,7 @@ public class RankingExpressionsTestCase extends SearchDefinitionTestCase { Search search = SearchBuilder.createFromDirectory("src/test/examples/rankingexpressioninfile", registry, new QueryProfileRegistry()).getSearch(); - new DerivedConfiguration(search, registry, new QueryProfileRegistry()); // rank profile parsing happens during deriving + new DerivedConfiguration(search, registry, new QueryProfileRegistry(), new ImportedModels()); // rank profile parsing happens during deriving } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java index 2b2d72dcf34..6b287c77a10 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java @@ -17,6 +17,7 @@ import com.yahoo.searchdefinition.SearchDefinitionTestCase; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import org.junit.Test; import java.util.List; @@ -199,9 +200,10 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase { "}\n"); builder.build(true, new BaseDeployLogger()); Search s = builder.getSearch(); - RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(queryProfiles); + RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(queryProfiles, new ImportedModels()); List<Pair<String, String>> testRankProperties = new RawRankProfile(test, queryProfiles, + new ImportedModels(), new AttributeFields(s)).configProperties(); return testRankProperties; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java new file mode 100644 index 00000000000..f8bc83735c7 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java @@ -0,0 +1,70 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.ml; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.yahoo.path.Path; + +import java.io.File; +import java.util.Optional; + +/** + * All models imported from the models/ directory in the application package + * + * @author bratseth + */ +public class ImportedModels { + + /** All imported models, indexed by their names */ + private final ImmutableMap<String, ImportedModel> importedModels; + + private static final ImmutableList<ModelImporter> importers = ImmutableList.of(new TensorFlowImporter(), new OnnxImporter()); + + /** Create a null imported models */ + public ImportedModels() { + importedModels = ImmutableMap.of(); + } + + public ImportedModels(File modelsDirectory) { + ImmutableMap.Builder<String, ImportedModel> builder = new ImmutableMap.Builder<>(); + + // Find all subdirectories recursively which contains a model we can read + importRecursively(modelsDirectory, builder); + importedModels = builder.build(); + } + + private static void importRecursively(File dir, ImmutableMap.Builder<String, ImportedModel> builder) { + if ( ! dir.isDirectory()) return; + for (File child : dir.listFiles()) { + Optional<ModelImporter> importer = findImporterOf(child); + if (importer.isPresent()) { + String name = toName(child); + builder.put(name, importer.get().importModel(name, child)); + } + else { + importRecursively(child, builder); + } + } + } + + private static Optional<ModelImporter> findImporterOf(File path) { + return importers.stream().filter(item -> item.canImport(path.toString())).findFirst(); + } + + /** + * Returns the model at the given location in the application package (lazily loaded), + * + * @param modelPath the full path to this model (file or directory, depending on model type) + * under the application package + * @throws IllegalArgumentException if the model cannot be loaded + */ + public ImportedModel get(File modelPath) { + return importedModels.get(toName(modelPath)); + } + + private static String toName(File modelPath) { + Path localPath = Path.fromString(modelPath.toString()).getChildPath(); + return localPath.toString().replace("/", "_").replace('.', '_'); + } + +} |