diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-08-22 14:36:03 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-08-22 14:36:03 +0200 |
commit | 4eb133b40206e20e3a70dae7aacec0f6b117e15d (patch) | |
tree | 0bb6fb1c306da10125efc5a5c0d6eec6deae0c22 /config-model/src/main/java/com | |
parent | 7392f9fdbee5f0a52ac9c056376b659b32500c60 (diff) |
Scope imported models to an entire application build
Diffstat (limited to 'config-model/src/main/java/com')
13 files changed, 77 insertions, 90 deletions
diff --git a/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java b/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java index ff6370f1738..574c25a2f84 100644 --- a/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java +++ b/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java @@ -1,7 +1,6 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.config.model.deploy; -import com.google.common.collect.ImmutableMap; import com.yahoo.component.Version; import com.yahoo.component.Vtag; import com.yahoo.config.application.api.ApplicationPackage; @@ -22,8 +21,8 @@ import com.yahoo.config.provision.Zone; import com.yahoo.io.reader.NamedReader; import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.SearchBuilder; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; import com.yahoo.vespa.config.ConfigDefinition; import com.yahoo.vespa.config.ConfigDefinitionBuilder; import com.yahoo.vespa.config.ConfigDefinitionKey; @@ -67,7 +66,7 @@ public class DeployState implements ConfigDefinitionStore { private final Zone zone; private final QueryProfiles queryProfiles; private final SemanticRules semanticRules; - //private final ImmutableMap<String, ImportedModel> importedMlModels; + private final ImportedModels importedModels; private final ValidationOverrides validationOverrides; private final Version wantedNodeVespaVersion; private final Instant now; @@ -101,6 +100,7 @@ public class DeployState implements ConfigDefinitionStore { this.zone = zone; this.queryProfiles = queryProfiles; // TODO: Remove this by seeing how pagetemplates are propagated this.semanticRules = semanticRules; // TODO: Remove this by seeing how pagetemplates are propagated + this.importedModels = new ImportedModels(applicationPackage.getFileReference(ApplicationPackage.MODELS_DIR)); this.validationOverrides = applicationPackage.getValidationOverrides().map(ValidationOverrides::fromXml).orElse(ValidationOverrides.empty); this.wantedNodeVespaVersion = wantedNodeVespaVersion; @@ -215,7 +215,7 @@ public class DeployState implements ConfigDefinitionStore { public SemanticRules getSemanticRules() { return semanticRules; } /** The (machine learned) models imported from the models/ directory, as an unmodifiable map indexed by model name */ - //public Map<String, ImportedModel> importedMlModels() { return importedMlModels; } + public ImportedModels getImportedModels() { return importedModels; } public Version getWantedNodeVespaVersion() { return wantedNodeVespaVersion; } diff --git a/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java b/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java index 0cfde3c655c..7404ae14a5d 100644 --- a/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java +++ b/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java @@ -32,6 +32,7 @@ public class MockApplicationPackage implements ApplicationPackage { public static final String MUSIC_SEARCHDEFINITION = createSearchDefinition("music", "foo"); public static final String BOOK_SEARCHDEFINITION = createSearchDefinition("book", "bar"); + private final File root; private final String hostsS; private final String servicesS; private final List<String> searchDefinitions; @@ -42,9 +43,11 @@ public class MockApplicationPackage implements ApplicationPackage { private final QueryProfileRegistry queryProfileRegistry; private final ApplicationMetaData applicationMetaData; - protected MockApplicationPackage(String hosts, String services, List<String> searchDefinitions, String searchDefinitionDir, + protected MockApplicationPackage(File root, String hosts, String services, List<String> searchDefinitions, + String searchDefinitionDir, String deploymentSpec, String validationOverrides, boolean failOnValidateXml, String queryProfile, String queryProfileType) { + this.root = root; this.hostsS = hosts; this.servicesS = services; this.searchDefinitions = searchDefinitions; @@ -57,6 +60,9 @@ public class MockApplicationPackage implements ApplicationPackage { applicationMetaData = new ApplicationMetaData("user", "dir", 0L, false, "application", "checksum", 0L, 0L); } + /** Returns the root of this application package relative to the current dir */ + protected File root() { return root; } + @Override public String getApplicationName() { return "mock application"; @@ -111,6 +117,11 @@ public class MockApplicationPackage implements ApplicationPackage { } @Override + public File getFileReference(Path path) { + return Path.fromString(root.toString()).append(path).toFile(); + } + + @Override public String getHostSource() { return "mock source"; } @@ -163,6 +174,7 @@ public class MockApplicationPackage implements ApplicationPackage { public static class Builder { + private File root = new File("nonexisting"); private String hosts = null; private String services = null; private List<String> searchDefinitions = Collections.emptyList(); @@ -176,6 +188,11 @@ public class MockApplicationPackage implements ApplicationPackage { public Builder() { } + public Builder withRoot(File root) { + this.root = root; + return this; + } + public Builder withEmptyHosts() { return this.withHosts(emptyHosts); } @@ -235,7 +252,7 @@ public class MockApplicationPackage implements ApplicationPackage { } public ApplicationPackage build() { - return new MockApplicationPackage(hosts, services, searchDefinitions, searchDefinitionDir, + return new MockApplicationPackage(root, hosts, services, searchDefinitions, searchDefinitionDir, deploymentSpec, validationOverrides, failOnValidateXml, queryProfile, queryProfileType); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index 9fec1983465..2e66784527d 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -16,6 +16,7 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.TypeContext; @@ -682,10 +683,10 @@ public class RankProfile implements Serializable, Cloneable { * Returns a copy of this where the content is optimized for execution. * Compiled profiles should never be modified. */ - public RankProfile compile(QueryProfileRegistry queryProfiles) { + public RankProfile compile(QueryProfileRegistry queryProfiles, ImportedModels importedModels) { try { RankProfile compiled = this.clone(); - compiled.compileThis(queryProfiles); + compiled.compileThis(queryProfiles, importedModels); return compiled; } catch (IllegalArgumentException e) { @@ -693,19 +694,19 @@ public class RankProfile implements Serializable, Cloneable { } } - private void compileThis(QueryProfileRegistry queryProfiles) { + private void compileThis(QueryProfileRegistry queryProfiles, ImportedModels importedModels) { parseExpressions(); checkNameCollisions(getMacros(), getConstants()); ExpressionTransforms expressionTransforms = new ExpressionTransforms(); // Macro compiling first pass: compile inline macros without resolving other macros - Map<String, Macro> inlineMacros = compileMacros(getInlineMacros(), queryProfiles, Collections.emptyMap(), expressionTransforms); + Map<String, Macro> inlineMacros = compileMacros(getInlineMacros(), queryProfiles, importedModels, Collections.emptyMap(), expressionTransforms); // Macro compiling second pass: compile all macros and insert previously compiled inline macros - macros = compileMacros(getMacros(), queryProfiles, inlineMacros, expressionTransforms); + macros = compileMacros(getMacros(), queryProfiles, importedModels, inlineMacros, expressionTransforms); - firstPhaseRanking = compile(this.getFirstPhaseRanking(), queryProfiles, getConstants(), inlineMacros, expressionTransforms); - secondPhaseRanking = compile(this.getSecondPhaseRanking(), queryProfiles, getConstants(), inlineMacros, expressionTransforms); + firstPhaseRanking = compile(this.getFirstPhaseRanking(), queryProfiles, importedModels, getConstants(), inlineMacros, expressionTransforms); + secondPhaseRanking = compile(this.getSecondPhaseRanking(), queryProfiles, importedModels, getConstants(), inlineMacros, expressionTransforms); } private void checkNameCollisions(Map<String, Macro> macros, Map<String, Value> constants) { @@ -723,12 +724,13 @@ public class RankProfile implements Serializable, Cloneable { private Map<String, Macro> compileMacros(Map<String, Macro> macros, QueryProfileRegistry queryProfiles, + ImportedModels importedModels, Map<String, Macro> inlineMacros, ExpressionTransforms expressionTransforms) { Map<String, Macro> compiledMacros = new LinkedHashMap<>(); for (Map.Entry<String, Macro> entry : macros.entrySet()) { Macro macro = entry.getValue().clone(); - RankingExpression exp = compile(macro.getRankingExpression(), queryProfiles, getConstants(), inlineMacros, expressionTransforms); + RankingExpression exp = compile(macro.getRankingExpression(), queryProfiles, importedModels, getConstants(), inlineMacros, expressionTransforms); macro.setRankingExpression(exp); compiledMacros.put(entry.getKey(), macro); } @@ -737,6 +739,7 @@ public class RankProfile implements Serializable, Cloneable { private RankingExpression compile(RankingExpression expression, QueryProfileRegistry queryProfiles, + ImportedModels importedModels, Map<String, Value> constants, Map<String, Macro> inlineMacros, ExpressionTransforms expressionTransforms) { @@ -745,6 +748,7 @@ public class RankProfile implements Serializable, Cloneable { RankProfileTransformContext context = new RankProfileTransformContext(this, queryProfiles, + importedModels, constants, inlineMacros, rankPropertiesOutput); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java index 985087e905b..4af26b72817 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java @@ -12,6 +12,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.derived.validation.Validation; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import java.io.IOException; import java.io.Writer; @@ -46,8 +47,11 @@ public class DerivedConfiguration { * modified. * @param rankProfileRegistry a {@link com.yahoo.searchdefinition.RankProfileRegistry} */ - public DerivedConfiguration(Search search, RankProfileRegistry rankProfileRegistry, QueryProfileRegistry queryProfiles) { - this(search, null, new BaseDeployLogger(), rankProfileRegistry, queryProfiles); + public DerivedConfiguration(Search search, + RankProfileRegistry rankProfileRegistry, + QueryProfileRegistry queryProfiles, + ImportedModels importedModels) { + this(search, null, new BaseDeployLogger(), rankProfileRegistry, queryProfiles, importedModels); } /** @@ -63,10 +67,12 @@ public class DerivedConfiguration { * @param rankProfileRegistry a {@link com.yahoo.searchdefinition.RankProfileRegistry} * @param queryProfiles the query profiles of this application */ - public DerivedConfiguration(Search search, List<Search> abstractSearchList, + public DerivedConfiguration(Search search, + List<Search> abstractSearchList, DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry, - QueryProfileRegistry queryProfiles) { + QueryProfileRegistry queryProfiles, + ImportedModels importedModels) { Validator.ensureNotNull("Search definition", search); if ( ! search.isProcessed()) { throw new IllegalArgumentException("Search '" + search.getName() + "' not processed."); @@ -88,7 +94,7 @@ public class DerivedConfiguration { summaries = new Summaries(search, deployLogger); summaryMap = new SummaryMap(search, summaries); juniperrc = new Juniperrc(search); - rankProfileList = new RankProfileList(search, attributeFields, rankProfileRegistry, queryProfiles); + rankProfileList = new RankProfileList(search, attributeFields, rankProfileRegistry, queryProfiles, importedModels); indexingScript = new IndexingScript(search); indexInfo = new IndexInfo(search); indexSchema = new IndexSchema(search); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java index 77645331d9e..1e978e43d6a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java @@ -3,6 +3,7 @@ package com.yahoo.searchdefinition.derived; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankProfileRegistry; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import com.yahoo.vespa.config.search.RankProfilesConfig; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.Search; @@ -26,24 +27,27 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ public RankProfileList(Search search, AttributeFields attributeFields, RankProfileRegistry rankProfileRegistry, - QueryProfileRegistry queryProfiles) { + QueryProfileRegistry queryProfiles, + ImportedModels importedModels) { setName(search.getName()); - deriveRankProfiles(rankProfileRegistry, queryProfiles, search, attributeFields); + deriveRankProfiles(rankProfileRegistry, queryProfiles, importedModels, search, attributeFields); } private void deriveRankProfiles(RankProfileRegistry rankProfileRegistry, QueryProfileRegistry queryProfiles, + ImportedModels importedModels, Search search, AttributeFields attributeFields) { RawRankProfile defaultProfile = new RawRankProfile(rankProfileRegistry.getRankProfile(search, "default"), queryProfiles, + importedModels, attributeFields); rankProfiles.put(defaultProfile.getName(), defaultProfile); for (RankProfile rank : rankProfileRegistry.localRankProfiles(search)) { if ("default".equals(rank.getName())) continue; - RawRankProfile rawRank = new RawRankProfile(rank, queryProfiles, attributeFields); + RawRankProfile rawRank = new RawRankProfile(rank, queryProfiles, importedModels, attributeFields); rankProfiles.put(rawRank.getName(), rawRank); } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java index b11d5962713..0d104a97698 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java @@ -9,6 +9,7 @@ import com.yahoo.searchdefinition.document.RankType; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.SerializationContext; @@ -47,9 +48,9 @@ public class RawRankProfile implements RankProfilesConfig.Producer { /** * Creates a raw rank profile from the given rank profile */ - public RawRankProfile(RankProfile rankProfile, QueryProfileRegistry queryProfiles, AttributeFields attributeFields) { + public RawRankProfile(RankProfile rankProfile, QueryProfileRegistry queryProfiles, ImportedModels importedModels, AttributeFields attributeFields) { this.name = rankProfile.getName(); - compressedProperties = compress(removePartFromKeys(new Deriver(rankProfile, queryProfiles, attributeFields).derive())); + compressedProperties = compress(removePartFromKeys(new Deriver(rankProfile, queryProfiles, importedModels, attributeFields).derive())); } private List<Pair<String, String>> removePartFromKeys(Map<String, String> map) { @@ -155,8 +156,8 @@ public class RawRankProfile implements RankProfilesConfig.Producer { /** * Creates a raw rank profile from the given rank profile */ - public Deriver(RankProfile rankProfile, QueryProfileRegistry queryProfiles, AttributeFields attributeFields) { - RankProfile compiled = rankProfile.compile(queryProfiles); + public Deriver(RankProfile rankProfile, QueryProfileRegistry queryProfiles, ImportedModels importedModels, AttributeFields attributeFields) { + RankProfile compiled = rankProfile.compile(queryProfiles, importedModels); attributeTypes = compiled.getAttributeTypes(); queryFeatureTypes = compiled.getQueryFeatureTypes(); deriveRankingFeatures(compiled); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java index 21e51b68677..a38fbe1aaa0 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java @@ -16,6 +16,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; @@ -77,14 +78,12 @@ public class ConvertedModel { * Create a converted model for a rank profile given from either an imported model, * or (if unavailable) from stored application package data. */ - public ConvertedModel(Path modelPath, - RankProfileTransformContext context, - ImportedModels importedModels) { + public ConvertedModel(Path modelPath, RankProfileTransformContext context) { this.modelPath = modelPath; this.modelName = toModelName(modelPath); ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), modelPath); if ( store.hasSourceModel()) - expressions = convertModel(store, context.rankProfile(), context.queryProfiles(), importedModels); + expressions = convertModel(store, context.rankProfile(), context.queryProfiles(), context.importedModels()); else expressions = transformFromStoredModel(store, context.rankProfile()); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ImportedModels.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ImportedModels.java deleted file mode 100644 index 7cf0a5d8b76..00000000000 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ImportedModels.java +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchdefinition.expressiontransforms; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import com.yahoo.path.Path; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; -import com.yahoo.searchlib.rankingexpression.integration.ml.ModelImporter; -import com.yahoo.searchlib.rankingexpression.integration.ml.OnnxImporter; -import com.yahoo.searchlib.rankingexpression.integration.ml.TensorFlowImporter; - -import java.io.File; -import java.util.HashMap; -import java.util.Map; - -/** - * All models imported from the models/ directory in the application package - * - * @author bratseth - */ -class ImportedModels { - - /** The cache of already imported models */ - private final Map<String, ImportedModel> importedModels = new HashMap<>(); - - private final ImmutableList<ModelImporter> importers = ImmutableList.of(new TensorFlowImporter(), new OnnxImporter()); - - ImportedModels() { - } - - /** - * Returns the model at the given location in the application package (lazily loaded), - * - * @param modelPath the full path to this model (file or directory, depending on model type) - * under the application package - * @throws IllegalArgumentException if the model cannot be loaded - */ - public ImportedModel get(File modelPath) { - String modelName = toName(modelPath); - ModelImporter importer = importers.stream().filter(item -> item.canImport(modelPath.toString())).findFirst().get(); - return importedModels.computeIfAbsent(modelName, __ -> importer.importModel(modelName, modelPath)); - } - - private static String toName(File modelPath) { - Path localPath = Path.fromString(modelPath.toString()).getChildPath(); - return localPath.toString().replace("/", "_").replace('.', '_'); - } - -} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java index 0b68a67acff..36dc200f3c9 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java @@ -3,13 +3,13 @@ package com.yahoo.searchdefinition.expressiontransforms; import com.yahoo.path.Path; -import com.yahoo.searchlib.rankingexpression.integration.ml.OnnxImporter; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import java.io.File; import java.io.UncheckedIOException; import java.util.HashMap; import java.util.Map; @@ -24,8 +24,6 @@ import java.util.Map; */ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> { - private final ImportedModels importedOnnxModels = new ImportedModels(); - /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */ private final Map<Path, ConvertedModel> convertedOnnxModels = new HashMap<>(); @@ -45,7 +43,7 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans try { Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0))); ConvertedModel convertedModel = - convertedOnnxModels.computeIfAbsent(modelPath, __ -> new ConvertedModel(modelPath, context, importedOnnxModels)); + convertedOnnxModels.computeIfAbsent(modelPath, __ -> new ConvertedModel(modelPath, context)); return convertedModel.expression(asFeatureArguments(feature.getArguments())); } catch (IllegalArgumentException | UncheckedIOException e) { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java index 5da5b3dabda..c7b4e85d74e 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java @@ -4,6 +4,7 @@ package com.yahoo.searchdefinition.expressiontransforms; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import com.yahoo.searchlib.rankingexpression.transform.TransformContext; import java.util.Map; @@ -17,23 +18,27 @@ public class RankProfileTransformContext extends TransformContext { private final RankProfile rankProfile; private final QueryProfileRegistry queryProfiles; + private final ImportedModels importedModels; private final Map<String, RankProfile.Macro> inlineMacros; private final Map<String, String> rankPropertiesOutput; public RankProfileTransformContext(RankProfile rankProfile, QueryProfileRegistry queryProfiles, + ImportedModels importedModels, Map<String, Value> constants, Map<String, RankProfile.Macro> inlineMacros, Map<String, String> rankPropertiesOutput) { super(constants); this.rankProfile = rankProfile; this.queryProfiles = queryProfiles; + this.importedModels = importedModels; this.inlineMacros = inlineMacros; this.rankPropertiesOutput = rankPropertiesOutput; } public RankProfile rankProfile() { return rankProfile; } public QueryProfileRegistry queryProfiles() { return queryProfiles; } + public ImportedModels importedModels() { return importedModels; } public Map<String, RankProfile.Macro> inlineMacros() { return inlineMacros; } public Map<String, String> rankPropertiesOutput() { return rankPropertiesOutput; } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index 4f15fb5a291..619c13da764 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -2,13 +2,13 @@ package com.yahoo.searchdefinition.expressiontransforms; import com.yahoo.path.Path; -import com.yahoo.searchlib.rankingexpression.integration.ml.TensorFlowImporter; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import java.io.File; import java.io.UncheckedIOException; import java.util.HashMap; import java.util.Map; @@ -22,8 +22,6 @@ import java.util.Map; */ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> { - private final ImportedModels importedTensorFlowModels = new ImportedModels(); - /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */ private final Map<Path, ConvertedModel> convertedTensorFlowModels = new HashMap<>(); @@ -43,7 +41,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil try { Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0))); ConvertedModel convertedModel = - convertedTensorFlowModels.computeIfAbsent(modelPath, __ -> new ConvertedModel(modelPath, context, importedTensorFlowModels)); + convertedTensorFlowModels.computeIfAbsent(modelPath, __ -> new ConvertedModel(modelPath, context)); return convertedModel.expression(asFeatureArguments(feature.getArguments())); } catch (IllegalArgumentException | UncheckedIOException e) { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java index 6e0aa562508..c7762f09851 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java @@ -289,9 +289,12 @@ public class IndexedSearchCluster extends SearchCluster if ( ! (search instanceof UnproperSearch)) { DocumentDatabase db = new DocumentDatabase(this, search.getName(), - new DerivedConfiguration(search, globalSearches, deployLogger(), + new DerivedConfiguration(search, + globalSearches, + deployLogger(), getRoot().getDeployState().rankProfileRegistry(), - getRoot().getDeployState().getQueryProfiles().getRegistry())); + getRoot().getDeployState().getQueryProfiles().getRegistry(), + getRoot().getDeployState().getImportedModels())); // TODO: remove explicit adding of user configs when the complete content model is built using builders. db.mergeUserConfigs(spec.getUserConfigs()); documentDbs.add(db); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/StreamingSearchCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/search/StreamingSearchCluster.java index e87df3d530b..66553ebadf6 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/search/StreamingSearchCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/search/StreamingSearchCluster.java @@ -104,7 +104,8 @@ public class StreamingSearchCluster extends SearchCluster implements } this.sdConfig = new DerivedConfiguration(localSearch, globalSearches, deployLogger(), getRoot().getDeployState().rankProfileRegistry(), - getRoot().getDeployState().getQueryProfiles().getRegistry()); + getRoot().getDeployState().getQueryProfiles().getRegistry(), + getRoot().getDeployState().getImportedModels()); } @Override public DerivedConfiguration getSdConfig() { |