diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-08-25 12:33:10 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-08-25 12:33:10 +0200 |
commit | 011d94ba6809d9c931a9e3d8d8bbcf9a28a97a61 (patch) | |
tree | 1cf090b79ef73ec1cd6e9799642ea62c86377578 | |
parent | ca44e13502f4d6b7efe1ba327973900f9b8e0f44 (diff) | |
parent | ed3923c95484e57e4eac43c25e985512ca3aa645 (diff) |
Merge pull request #6672 from vespa-engine/bratseth/generate-rank-profiles-for-all-models-part-6
Bratseth/generate rank profiles for all models part 6
72 files changed, 5732 insertions, 372 deletions
diff --git a/config-model/src/main/java/com/yahoo/config/model/ApplicationConfigProducerRoot.java b/config-model/src/main/java/com/yahoo/config/model/ApplicationConfigProducerRoot.java index 912968747df..7a5969585e9 100644 --- a/config-model/src/main/java/com/yahoo/config/model/ApplicationConfigProducerRoot.java +++ b/config-model/src/main/java/com/yahoo/config/model/ApplicationConfigProducerRoot.java @@ -125,11 +125,11 @@ public class ApplicationConfigProducerRoot extends AbstractConfigProducer<Abstra } // TODO: Do this as another config model depending on the other models - public void setupRouting(ConfigModelRepo configModels) { + public void setupRouting(VespaModel vespaModel, ConfigModelRepo configModels) { if (admin != null) { Routing routing = configModels.getRouting(); if (routing == null) { - routing = new Routing(ConfigModelContext.create(configModels, this, "routing")); + routing = new Routing(ConfigModelContext.create(vespaModel, configModels, this, "routing")); configModels.add(routing); } this.routing = routing; diff --git a/config-model/src/main/java/com/yahoo/config/model/ConfigModelContext.java b/config-model/src/main/java/com/yahoo/config/model/ConfigModelContext.java index 78a8c161c3b..eea0b16b8e1 100644 --- a/config-model/src/main/java/com/yahoo/config/model/ConfigModelContext.java +++ b/config-model/src/main/java/com/yahoo/config/model/ConfigModelContext.java @@ -5,6 +5,7 @@ import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.config.application.api.DeployLogger; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.config.model.producer.AbstractConfigProducer; +import com.yahoo.vespa.model.VespaModel; import java.util.stream.Stream; @@ -19,16 +20,18 @@ public final class ConfigModelContext { private final AbstractConfigProducer parent; private final String producerId; private final DeployState deployState; + private final VespaModel vespaModel; private final ConfigModelRepoAdder configModelRepoAdder; private final ApplicationType applicationType; private ConfigModelContext(ApplicationType applicationType, DeployState deployState, - ConfigModelRepoAdder configModelRepoAdder, + VespaModel vespaModel, ConfigModelRepoAdder configModelRepoAdder, AbstractConfigProducer parent, String producerId) { this.applicationType = applicationType; this.deployState = deployState; + this.vespaModel = vespaModel; this.configModelRepoAdder = configModelRepoAdder; this.parent = parent; this.producerId = producerId; @@ -40,18 +43,23 @@ public final class ConfigModelContext { public DeployLogger getDeployLogger() { return deployState.getDeployLogger(); } public DeployState getDeployState() { return deployState; } public ApplicationType getApplicationType() { return applicationType; } + public VespaModel vespaModel() { return vespaModel; } /** Returns write access to the config model repo, or null (only) if this is improperly initialized during testing */ public ConfigModelRepoAdder getConfigModelRepoAdder() { return configModelRepoAdder; } /** Create a new context with a different parent */ public ConfigModelContext withParent(AbstractConfigProducer newParent) { - return ConfigModelContext.create(deployState, configModelRepoAdder, newParent, producerId); + return ConfigModelContext.create(deployState, vespaModel, configModelRepoAdder, newParent, producerId); } /** Create a new context with a different config model producer id */ public ConfigModelContext withId(String producerId) { - return ConfigModelContext.create(deployState, configModelRepoAdder, parent, producerId); + return ConfigModelContext.create(deployState, vespaModel, configModelRepoAdder, parent, producerId); + } + + public ConfigModelContext with(VespaModel vespaModel) { + return ConfigModelContext.create(deployState, vespaModel, configModelRepoAdder, parent, producerId); } /** @@ -61,9 +69,9 @@ public final class ConfigModelContext { * @param producerId the id to be used for the config model. * @return a model context that can be passed to a model. */ - public static ConfigModelContext create(ConfigModelRepoAdder configModelRepoAdder, + public static ConfigModelContext create(VespaModel vespaModel, ConfigModelRepoAdder configModelRepoAdder, AbstractConfigProducer parent, String producerId) { - return create(parent.getRoot().getDeployState(), configModelRepoAdder, parent, producerId); + return create(parent.getRoot().getDeployState(), vespaModel, configModelRepoAdder, parent, producerId); } /** @@ -74,9 +82,12 @@ public final class ConfigModelContext { * @param producerId the id to be used for the config model * @return a model context that can be passed to a model */ - public static ConfigModelContext create(DeployState deployState, ConfigModelRepoAdder configModelRepoAdder, - AbstractConfigProducer parent, String producerId) { - return new ConfigModelContext(ApplicationType.DEFAULT, deployState, configModelRepoAdder, parent, producerId); + public static ConfigModelContext create(DeployState deployState, + VespaModel vespaModel, + ConfigModelRepoAdder configModelRepoAdder, + AbstractConfigProducer parent, + String producerId) { + return new ConfigModelContext(ApplicationType.DEFAULT, deployState, vespaModel, configModelRepoAdder, parent, producerId); } /** @@ -90,10 +101,11 @@ public final class ConfigModelContext { */ public static ConfigModelContext create(ApplicationType applicationType, DeployState deployState, + VespaModel vespaModel, ConfigModelRepoAdder configModelRepoAdder, AbstractConfigProducer parent, String producerId) { - return new ConfigModelContext(applicationType, deployState, configModelRepoAdder, parent, producerId); + return new ConfigModelContext(applicationType, deployState, vespaModel, configModelRepoAdder, parent, producerId); } public enum ApplicationType { diff --git a/config-model/src/main/java/com/yahoo/config/model/ConfigModelRepo.java b/config-model/src/main/java/com/yahoo/config/model/ConfigModelRepo.java index 5ec34b62ca2..60089f04572 100644 --- a/config-model/src/main/java/com/yahoo/config/model/ConfigModelRepo.java +++ b/config-model/src/main/java/com/yahoo/config/model/ConfigModelRepo.java @@ -15,6 +15,7 @@ import com.yahoo.log.LogLevel; import com.yahoo.path.Path; import com.yahoo.text.XML; import com.yahoo.config.model.producer.AbstractConfigProducer; +import com.yahoo.vespa.model.VespaModel; import com.yahoo.vespa.model.builder.VespaModelBuilder; import com.yahoo.vespa.model.clients.Clients; import com.yahoo.vespa.model.content.Content; @@ -71,10 +72,13 @@ public class ConfigModelRepo implements ConfigModelRepoAdder, Serializable, Iter public Map<String,ConfigModel> asMap() { return Collections.unmodifiableMap(configModelMap); } /** Initialize part 1.: Reads the config models used in the application package. */ - public void readConfigModels(DeployState deployState, VespaModelBuilder builder, - ApplicationConfigProducerRoot root, ConfigModelRegistry configModelRegistry) throws IOException, SAXException { + public void readConfigModels(DeployState deployState, + VespaModel vespaModel, + VespaModelBuilder builder, + ApplicationConfigProducerRoot root, + ConfigModelRegistry configModelRegistry) throws IOException, SAXException { Element userServicesElement = getServicesFromApp(deployState.getApplicationPackage()); - readConfigModels(root, userServicesElement, deployState, configModelRegistry); + readConfigModels(root, userServicesElement, deployState, vespaModel, configModelRegistry); builder.postProc(root, this); } @@ -104,7 +108,11 @@ public class ConfigModelRepo implements ConfigModelRepoAdder, Serializable, Iter * @param servicesRoot XML root node of the services file */ @SuppressWarnings("deprecation") - private void readConfigModels(ApplicationConfigProducerRoot root, Element servicesRoot, DeployState deployState, ConfigModelRegistry configModelRegistry) throws IOException, SAXException { + private void readConfigModels(ApplicationConfigProducerRoot root, + Element servicesRoot, + DeployState deployState, + VespaModel vespaModel, + ConfigModelRegistry configModelRegistry) throws IOException, SAXException { final Map<ConfigModelBuilder, List<Element>> model2Element = new LinkedHashMap<>(); ModelGraphBuilder graphBuilder = new ModelGraphBuilder(); @@ -140,7 +148,7 @@ public class ConfigModelRepo implements ConfigModelRepoAdder, Serializable, Iter } for (ModelNode node : graphBuilder.build().topologicalSort()) - buildModels(node, getApplicationType(servicesRoot), deployState, root, model2Element.get(node.builder)); + buildModels(node, getApplicationType(servicesRoot), deployState, vespaModel, root, model2Element.get(node.builder)); for (ConfigModel model : configModels) model.initialize(ConfigModelRepo.this); // XXX deprecated } @@ -174,10 +182,11 @@ public class ConfigModelRepo implements ConfigModelRepoAdder, Serializable, Iter private void buildModels(ModelNode node, ApplicationType applicationType, DeployState deployState, + VespaModel vespaModel, AbstractConfigProducer parent, List<Element> elements) { for (Element servicesElement : elements) { - ConfigModel model = buildModel(node, applicationType, deployState, parent, servicesElement); + ConfigModel model = buildModel(node, applicationType, deployState, vespaModel, parent, servicesElement); if (model.isServing()) add(model); } @@ -186,10 +195,11 @@ public class ConfigModelRepo implements ConfigModelRepoAdder, Serializable, Iter private ConfigModel buildModel(ModelNode node, ApplicationType applicationType, DeployState deployState, + VespaModel vespaModel, AbstractConfigProducer parent, Element servicesElement) { ConfigModelBuilder builder = node.builder; - ConfigModelContext context = ConfigModelContext.create(applicationType, deployState, this, parent, getIdString(servicesElement)); + ConfigModelContext context = ConfigModelContext.create(applicationType, deployState, vespaModel, this, parent, getIdString(servicesElement)); return builder.build(node, servicesElement, context); } diff --git a/config-model/src/main/java/com/yahoo/config/model/ConfigModelRepoAdder.java b/config-model/src/main/java/com/yahoo/config/model/ConfigModelRepoAdder.java index b16b65b9540..731cae48881 100644 --- a/config-model/src/main/java/com/yahoo/config/model/ConfigModelRepoAdder.java +++ b/config-model/src/main/java/com/yahoo/config/model/ConfigModelRepoAdder.java @@ -5,7 +5,7 @@ package com.yahoo.config.model; * An interface which provides addition of new config models. * This exists because some models need to add additional models during the build phase so *write* access * to the config model repo is needed. *Read* access, on the other hand needs to happen through config model dependency - * inkection to avoid circular dependencies or undeclared dependencies working by accident. + * injection to avoid circular dependencies or undeclared dependencies working by accident. * * @author bratseth */ diff --git a/config-model/src/main/java/com/yahoo/config/model/builder/xml/ConfigModelBuilder.java b/config-model/src/main/java/com/yahoo/config/model/builder/xml/ConfigModelBuilder.java index bb0d24f9b26..b1e197db4eb 100644 --- a/config-model/src/main/java/com/yahoo/config/model/builder/xml/ConfigModelBuilder.java +++ b/config-model/src/main/java/com/yahoo/config/model/builder/xml/ConfigModelBuilder.java @@ -9,6 +9,7 @@ import com.yahoo.config.model.ConfigModelRepo; import com.yahoo.config.model.api.ConfigModelPlugin; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.config.model.producer.AbstractConfigProducer; +import com.yahoo.vespa.model.VespaModel; import org.w3c.dom.Element; import java.lang.reflect.Constructor; @@ -54,9 +55,9 @@ public abstract class ConfigModelBuilder<MODEL extends ConfigModel> extends Abst * @param parent the root config producer this should be added to * @param spec the XML element this is constructed from */ - public final MODEL build(DeployState deployState, ConfigModelRepo configModelRepo, + public final MODEL build(DeployState deployState, VespaModel vespaModel, ConfigModelRepo configModelRepo, AbstractConfigProducer parent, Element spec) { - ConfigModelContext context = ConfigModelContext.create(deployState, configModelRepo, parent, getIdString(spec)); + ConfigModelContext context = ConfigModelContext.create(deployState, vespaModel, configModelRepo, parent, getIdString(spec)); return build(new DefaultModelInstanceFactory(), spec, context); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/DefaultRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/DefaultRankProfile.java index 82d35c89e85..cbbcee0dcfa 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/DefaultRankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/DefaultRankProfile.java @@ -31,31 +31,27 @@ public class DefaultRankProfile extends RankProfile { public void setInherited(String inheritedName) { } - /** - * Returns null, the default rank profile can not inherit anything - */ + /** Returns null, the default rank profile can not inherit anything */ public String getInheritedName() { return null; } - /** - * Returns the rank boost value of the given field - */ + /** Returns the rank boost value of the given field */ public RankSetting getRankSetting(String fieldOrIndex,RankSetting.Type type) { - RankSetting setting=super.getRankSetting(fieldOrIndex,type); - if (setting!=null) return setting; + RankSetting setting = super.getRankSetting(fieldOrIndex,type); + if (setting != null) return setting; - SDField field=getSearch().getConcreteField(fieldOrIndex); - if (field!=null) { - setting=toRankSetting(field,type); - if (setting!=null) + SDField field = getSearch().getConcreteField(fieldOrIndex); + if (field != null) { + setting = toRankSetting(field,type); + if (setting != null) return setting; } - Index index=getSearch().getIndex(fieldOrIndex); - if (index!=null) { - setting=toRankSetting(index,type); - if (setting!=null) + Index index = getSearch().getIndex(fieldOrIndex); + if (index != null) { + setting = toRankSetting(index,type); + if (setting != null) return setting; } @@ -92,37 +88,36 @@ public class DefaultRankProfile extends RankProfile { * explicitly in this profile or in fields */ public Set<RankSetting> rankSettings() { - Set<RankSetting> settings=new LinkedHashSet<>(20); + Set<RankSetting> settings = new LinkedHashSet<>(20); settings.addAll(this.rankSettings); for (SDField field : getSearch().allConcreteFields() ) { - addSetting(field,RankSetting.Type.WEIGHT,settings); - addSetting(field,RankSetting.Type.RANKTYPE,settings); - addSetting(field,RankSetting.Type.LITERALBOOST,settings); - addSetting(field,RankSetting.Type.PREFERBITVECTOR,settings); + addSetting(field, RankSetting.Type.WEIGHT, settings); + addSetting(field, RankSetting.Type.RANKTYPE, settings); + addSetting(field, RankSetting.Type.LITERALBOOST, settings); + addSetting(field, RankSetting.Type.PREFERBITVECTOR, settings); } // Foer settings that really pertains to indexes do the explicit indexes too for (Index index : getSearch().getExplicitIndices()) { - addSetting(index,RankSetting.Type.PREFERBITVECTOR,settings); + addSetting(index, RankSetting.Type.PREFERBITVECTOR, settings); } return settings; } - private void addSetting(SDField field,RankSetting.Type type,Set<RankSetting> settings) { + private void addSetting(SDField field, RankSetting.Type type, Set<RankSetting> settings) { if (type.isIndexLevel()) { - addIndexSettings(field,type,settings); + addIndexSettings(field, type, settings); } else { - RankSetting setting=toRankSetting(field,type); - if (setting==null) return; + RankSetting setting = toRankSetting(field, type); + if (setting == null) return; settings.add(setting); } } - private void addIndexSettings(SDField field,RankSetting.Type type,Set<RankSetting> settings) { + private void addIndexSettings(SDField field, RankSetting.Type type, Set<RankSetting> settings) { for (Iterator i = field.getFieldNameAsIterator(); i.hasNext(); ) { - String indexName=(String)i.next(); - Index explicitIndex=field.getIndex(indexName); + String indexName = (String)i.next(); // TODO: Make a ranking object in the index override the field level ranking object if (type.equals(RankSetting.Type.PREFERBITVECTOR) && field.getRanking().isFilter()) { @@ -131,9 +126,9 @@ public class DefaultRankProfile extends RankProfile { } } - private void addSetting(Index index,RankSetting.Type type,Set<RankSetting> settings) { - RankSetting setting=toRankSetting(index,type); - if (setting==null) return; + private void addSetting(Index index, RankSetting.Type type, Set<RankSetting> settings) { + RankSetting setting = toRankSetting(index, type); + if (setting == null) return; settings.add(setting); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index 2e66784527d..9d6a1351724 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -20,6 +20,7 @@ import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.TypeContext; +import com.yahoo.vespa.model.VespaModel; import java.io.File; import java.io.IOException; @@ -39,6 +40,7 @@ import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; +import java.util.stream.Stream; /** * Represents a rank profile - a named set of ranking settings @@ -50,8 +52,11 @@ public class RankProfile implements Serializable, Cloneable { /** The search definition-unique name of this rank profile */ private final String name; - /** The search definition owning this profile, or null if none */ - private Search search = null; + /** The search definition owning this profile, or null if global (owned by a model) */ + private final Search search; + + /** The model owning this profile if it is global, or null if it is owned by a search definition */ + private final VespaModel model; /** The name of the rank profile inherited by this */ private String inheritedName = null; @@ -110,63 +115,80 @@ public class RankProfile implements Serializable, Cloneable { private final TypeSettings queryFeatureTypes = new TypeSettings(); /** - * Creates a new rank profile + * Creates a new rank profile for a particular search definition * * @param name the name of the new profile * @param search the search definition owning this profile - * @param rankProfileRegistry The {@link com.yahoo.searchdefinition.RankProfileRegistry} to use for storing + * @param rankProfileRegistry the {@link com.yahoo.searchdefinition.RankProfileRegistry} to use for storing * and looking up rank profiles. */ public RankProfile(String name, Search search, RankProfileRegistry rankProfileRegistry) { - this.name = name; - this.search = search; + this.name = Objects.requireNonNull(name, "name cannot be null"); + this.search = Objects.requireNonNull(search, "search cannot be null"); + this.model = null; this.rankProfileRegistry = rankProfileRegistry; } - public String getName() { return name; } - /** - * Returns the search definition owning this, or null if none + * Creates a global rank profile * - * @return The search definition. + * @param name the name of the new profile + * @param model the model owning this profile */ - public Search getSearch() { - return search; + public RankProfile(String name, VespaModel model, RankProfileRegistry rankProfileRegistry) { + this.name = Objects.requireNonNull(name, "name cannot be null"); + this.search = null; + this.model = Objects.requireNonNull(model, "model cannot be null"); + this.rankProfileRegistry = rankProfileRegistry; + } + + public String getName() { return name; } + + /** Returns the search definition owning this, or null if it is global */ + public Search getSearch() { return search; } + + /** Returns the application this is part of */ + public ApplicationPackage applicationPackage() { + return search != null ? search.applicationPackage() : model.applicationPackage(); + } + + /** Returns the ranking constants of the owner of this */ + public RankingConstants rankingConstants() { + return search != null ? search.rankingConstants() : model.rankingConstants(); + } + + private Stream<ImmutableSDField> allFields() { + return search != null ? search.allFields() : Stream.empty(); + } + + private Stream<ImmutableSDField> allImportedFields() { + return search != null ? search.allImportedFields() : Stream.empty(); } /** * Sets the name of the rank profile this inherits. Both rank profiles must be present in the same search * definition - * - * @param inheritedName The name of the profile that this inherits from. */ public void setInherited(String inheritedName) { this.inheritedName = inheritedName; } - /** - * Returns the name of the profile this one inherits, or null if none is inherited - * - * @return The inherited name. - */ + /** Returns the name of the profile this one inherits, or null if none is inherited */ public String getInheritedName() { return inheritedName; } - /** - * Returns the inherited rank profile, or null if there is none - * - * @return The inherited profile. - */ + /** Returns the inherited rank profile, or null if there is none */ public RankProfile getInherited() { - if (getSearch()==null) return getInheritedFromRegistry(inheritedName); - RankProfile inheritedInThisSearch = rankProfileRegistry.getRankProfile(search, inheritedName); - if (inheritedInThisSearch!=null) return inheritedInThisSearch; + if (getSearch() == null) return getInheritedFromRegistry(inheritedName); + + RankProfile inheritedInThisSearch = rankProfileRegistry.get(search, inheritedName); + if (inheritedInThisSearch != null) return inheritedInThisSearch; return getInheritedFromRegistry(inheritedName); } private RankProfile getInheritedFromRegistry(String inheritedName) { - for (RankProfile r : rankProfileRegistry.allRankProfiles()) { + for (RankProfile r : rankProfileRegistry.all()) { if (r.getName().equals(inheritedName)) { return r; } @@ -177,8 +199,8 @@ public class RankProfile implements Serializable, Cloneable { /** * Returns whether this profile inherits (directly or indirectly) the given profile * - * @param name The profile name to compare this to. - * @return Whether or not this inherits from the named profile. + * @param name the profile name to compare this to. + * @return whether or not this inherits from the named profile. */ public boolean inherits(String name) { RankProfile parent = getInherited(); @@ -190,10 +212,6 @@ public class RankProfile implements Serializable, Cloneable { return false; } - /** - * change match settings - * @param settings The new match settings - **/ public void setMatchPhaseSettings(MatchPhaseSettings settings) { settings.checkValid(); this.matchPhaseSettings = settings; @@ -219,7 +237,7 @@ public class RankProfile implements Serializable, Cloneable { * * @param field The field whose settings to return. * @param type The type that the field is required to be. - * @return The rank setting found, or null. + * @return the rank setting found, or null. */ public RankSetting getDeclaredRankSetting(String field, RankSetting.Type type) { for (Iterator<RankSetting> i = declaredRankSettingIterator(); i.hasNext();) { @@ -236,9 +254,9 @@ public class RankProfile implements Serializable, Cloneable { * Returns a rank setting of field or index, or null if there is no such rank setting in this profile or one it * inherits * - * @param field The field whose settings to return. - * @param type The type that the field is required to be. - * @return The rank setting found, or null. + * @param field the field whose settings to return + * @param type the type that the field is required to be + * @return the rank setting found, or null */ public RankSetting getRankSetting(String field, RankSetting.Type type) { RankSetting rankSetting = getDeclaredRankSetting(field, type); @@ -252,7 +270,7 @@ public class RankProfile implements Serializable, Cloneable { /** * Returns the rank settings in this rank profile * - * @return An iterator for the declared rank setting. + * @return an iterator for the declared rank setting */ public Iterator<RankSetting> declaredRankSettingIterator() { return Collections.unmodifiableSet(rankSettings).iterator(); @@ -261,14 +279,14 @@ public class RankProfile implements Serializable, Cloneable { /** * Returns all settings in this profile or any profile it inherits * - * @return An iterator for all rank settings of this. + * @return an iterator for all rank settings of this */ public Iterator<RankSetting> rankSettingIterator() { return rankSettings().iterator(); } /** - * Returns a snapshot of the rank settings of this and everything it inherits + * Returns a snapshot of the rank settings of this and everything it inherits. * Changes to the returned set will not be reflected in this rank profile. */ public Set<RankSetting> rankSettings() { @@ -346,6 +364,7 @@ public class RankProfile implements Serializable, Cloneable { /** * Called by parser to store the expression string, for delayed evaluation + * * @param exp ranking expression for second phase */ public void setSecondPhaseRankingString(String exp) { @@ -354,6 +373,7 @@ public class RankProfile implements Serializable, Cloneable { /** * Called by parser to store the expression string, for delayed evaluation + * * @param exp ranking expression for first phase */ public void setFirstPhaseRankingString(String exp) { @@ -528,8 +548,11 @@ public class RankProfile implements Serializable, Cloneable { return null; } - public void addMacro(String name, boolean inline) { - macros.put(name, new Macro(name, inline)); + /** Creates a new (empty) macro and returns it */ + public Macro addMacro(String name, boolean inline) { + Macro macro = new Macro(name, inline); + macros.put(name, macro); + return macro; } /** Returns an unmodifiable view of the macros in this */ @@ -571,6 +594,7 @@ public class RankProfile implements Serializable, Cloneable { /** * Returns all filter fields in this profile and any profile it inherits. + * * @return the set of all filter fields */ public Set<String> allFilterFields() { @@ -770,11 +794,11 @@ public class RankProfile implements Serializable, Cloneable { // Add small and large constants, respectively getConstants().forEach((k, v) -> context.setType(FeatureNames.asConstantFeature(k), v.type())); - getSearch().getRankingConstants().forEach((k, v) -> context.setType(FeatureNames.asConstantFeature(k), v.getTensorType())); + rankingConstants().asMap().forEach((k, v) -> context.setType(FeatureNames.asConstantFeature(k), v.getTensorType())); // Add attributes - getSearch().allFields().forEach(field -> addAttributeFeatureTypes(field, context)); - getSearch().allImportedFields().forEach(field -> addAttributeFeatureTypes(field, context)); + allFields().forEach(field -> addAttributeFeatureTypes(field, context)); + allImportedFields().forEach(field -> addAttributeFeatureTypes(field, context)); // Add query features from rank profile types reached from the "default" profile for (QueryProfileType queryProfileType : queryProfiles.getTypeRegistry().allComponents()) { @@ -868,7 +892,7 @@ public class RankProfile implements Serializable, Cloneable { public Object getValue() { return value; } - /** @return The value as an int, or a negative value if it is not an integer */ + /** Returns the value as an int, or a negative value if it is not an integer */ public int getIntValue() { if (value instanceof Integer) { return ((Integer)value); @@ -934,8 +958,6 @@ public class RankProfile implements Serializable, Cloneable { /** * Represents a declared macro in the profile. It is, after parsing, transformed into ExpressionMacro - * - * @author vegardh */ public static class Macro implements Serializable, Cloneable { @@ -1021,6 +1043,7 @@ public class RankProfile implements Serializable, Cloneable { public int getMinGroups() { return minGroups; } public double getCutoffFactor() { return cutoffFactor; } public Diversity.CutoffStrategy getCutoffStrategy() { return cutoffStrategy; } + public void checkValid() { if (attribute == null || attribute.isEmpty()) { throw new IllegalArgumentException("'diversity' did not set non-empty diversity attribute name."); @@ -1035,6 +1058,7 @@ public class RankProfile implements Serializable, Cloneable { } public static class MatchPhaseSettings { + private String attribute = null; private boolean ascending = false; private int maxHits = 0; // try to get this many hits before degrading the match phase @@ -1047,6 +1071,7 @@ public class RankProfile implements Serializable, Cloneable { value.checkValid(); diversity = value; } + public void setAscending(boolean value) { ascending = value; } public void setAttribute(String value) { attribute = value; } public void setMaxHits(int value) { maxHits = value; } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfileRegistry.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfileRegistry.java index 9e1e42e0821..53afebfd93b 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfileRegistry.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfileRegistry.java @@ -1,8 +1,6 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition; -import com.yahoo.searchdefinition.expressiontransforms.ExpressionTransforms; - import java.util.Arrays; import java.util.Collection; import java.util.Collections; @@ -16,7 +14,7 @@ import java.util.Set; * Having both of these mappings consolidated here make it easier to remove dependencies on these mappings at * run time, since it is essentially only used when building rank profile config at deployment time. * - * TODO: Rank profiles should be stored under its owning Search instance. + * Global rank profiles are represented by the Search key null. * * @author Ulf Lilleengen */ @@ -30,8 +28,8 @@ public class RankProfileRegistry { public static RankProfileRegistry createRankProfileRegistryWithBuiltinRankProfiles(Search search) { RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); - rankProfileRegistry.addRankProfile(new DefaultRankProfile(search, rankProfileRegistry)); - rankProfileRegistry.addRankProfile(new UnrankedRankProfile(search, rankProfileRegistry)); + rankProfileRegistry.add(new DefaultRankProfile(search, rankProfileRegistry)); + rankProfileRegistry.add(new UnrankedRankProfile(search, rankProfileRegistry)); return rankProfileRegistry; } @@ -40,51 +38,54 @@ public class RankProfileRegistry { * * @param rankProfile the rank profile to add */ - public void addRankProfile(RankProfile rankProfile) { + public void add(RankProfile rankProfile) { if ( ! rankProfiles.containsKey(rankProfile.getSearch())) { rankProfiles.put(rankProfile.getSearch(), new LinkedHashMap<>()); } - checkForDuplicateRankProfile(rankProfile); + checkForDuplicate(rankProfile); rankProfiles.get(rankProfile.getSearch()).put(rankProfile.getName(), rankProfile); rankProfileToSearch.put(rankProfile, rankProfile.getSearch()); } - private void checkForDuplicateRankProfile(RankProfile rankProfile) { + private void checkForDuplicate(RankProfile rankProfile) { String rankProfileName = rankProfile.getName(); RankProfile existingRangProfileWithSameName = rankProfiles.get(rankProfile.getSearch()).get(rankProfileName); if (existingRangProfileWithSameName == null) return; if ( ! overridableRankProfileNames.contains(rankProfileName)) { throw new IllegalArgumentException("Cannot add rank profile '" + rankProfileName + "' in search definition '" - + rankProfile.getSearch().getName() + "', since it already exists"); + + rankProfile.getSearch().getName() + "', since it already exists"); } } /** * Returns a named rank profile, null if the search definition doesn't have one with the given name * - * @param search The {@link Search} that owns the rank profile. - * @param name The name of the rank profile - * @return The RankProfile to return. + * @param search the {@link Search} that owns the rank profile. + * @param name the name of the rank profile + * @return the RankProfile to return. */ - public RankProfile getRankProfile(Search search, String name) { - return rankProfiles.get(search).get(name); + public RankProfile get(Search search, String name) { + Map<String, RankProfile> profiles = rankProfiles.get(search); + if (profiles == null) return null; + return profiles.get(name); } /** * Rank profiles that are collected across clusters. * @return A set of global {@link RankProfile} instances. */ - public Set<RankProfile> allRankProfiles() { + public Set<RankProfile> all() { return rankProfileToSearch.keySet(); } /** - * Rank profiles that are collected for a given search definition - * @param search {@link Search} to get rank profiles for. - * @return A collection of local {@link RankProfile} instances. + * Returns the rank profiles of a given search definition. + * + * @param search {@link Search} to get rank profiles for + * @return a collection of {@link RankProfile} instances */ - public Collection<RankProfile> localRankProfiles(Search search) { + public Collection<RankProfile> rankProfilesOf(Search search) { Map<String, RankProfile> mapping = rankProfiles.get(search); if (mapping == null) { return Collections.emptyList(); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java new file mode 100644 index 00000000000..164cb7f808e --- /dev/null +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java @@ -0,0 +1,36 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchdefinition; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** + * Constant values for ranking/model execution tied to a search definition, or globally to an application + * package + * + * @author bratseth + */ +public class RankingConstants { + + private final Map<String, RankingConstant> constants = new HashMap<>(); + + public void add(RankingConstant constant) { + constant.validate(); + String name = constant.getName(); + if (constants.containsKey(name)) + throw new IllegalArgumentException("Ranking constant '" + name + "' defined twice"); + constants.put(name, constant); + } + + /** Returns the ranking constant with the given name, or null if not present */ + public RankingConstant get(String name) { + return constants.get(name); + } + + /** Returns a read-only map of the ranking constants in this indexed by name */ + public Map<String, RankingConstant> asMap() { + return Collections.unmodifiableMap(constants); + } + +} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/Search.java b/config-model/src/main/java/com/yahoo/searchdefinition/Search.java index 1ab76afc9c0..f42d5de21e8 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/Search.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/Search.java @@ -82,16 +82,16 @@ public class Search implements Serializable, ImmutableSearch { private Map<String, Index> indices = new LinkedHashMap<>(); // The explicitly defined summaries of this search definition. - private Map<String, DocumentSummary> summaries = new LinkedHashMap<>(); // _Must_ preserve order + private Map<String, DocumentSummary> summaries = new LinkedHashMap<>(); - // Ranking constants defined inside this s.d. - private Map<String, RankingConstant> rankingConstants = new HashMap<>(); + // Ranking constants of this + private RankingConstants rankingConstants = new RankingConstants(); private Optional<TemporaryImportedFields> temporaryImportedFields = Optional.of(new TemporaryImportedFields()); private Optional<ImportedFields> importedFields = Optional.empty(); - private ApplicationPackage sourceApplication; + private ApplicationPackage applicationPackage; /** * Creates a search definition which just holds a set of documents which should not (here, directly) be searchable @@ -103,10 +103,10 @@ public class Search implements Serializable, ImmutableSearch { /** * Creates a proper search definition * @param name of the the searchdefinition - * @param sourceApplication the application containing this + * @param applicationPackage the application containing this */ - public Search(String name, ApplicationPackage sourceApplication) { - this.sourceApplication = sourceApplication; + public Search(String name, ApplicationPackage applicationPackage) { + this.applicationPackage = applicationPackage; this.name = name; } @@ -162,18 +162,7 @@ public class Search implements Serializable, ImmutableSearch { docType = document; } - public void addRankingConstant(RankingConstant constant) { - constant.validate(); - String name = constant.getName(); - if (rankingConstants.containsKey(name)) - throw new IllegalArgumentException("Ranking constant '" + name + "' defined twice"); - rankingConstants.put(name, constant); - } - - /** Returns a read-only map of the ranking constants in this indexed by name */ - public Map<String, RankingConstant> getRankingConstants() { - return Collections.unmodifiableMap(rankingConstants); - } + public RankingConstants rankingConstants() { return rankingConstants; } public Optional<TemporaryImportedFields> temporaryImportedFields() { return temporaryImportedFields; @@ -260,10 +249,10 @@ public class Search implements Serializable, ImmutableSearch { * Returns the content of a ranking expression file */ public Reader getRankingExpression(String fileName) { - return sourceApplication.getRankingExpression(fileName); + return applicationPackage.getRankingExpression(fileName); } - public ApplicationPackage sourceApplication() { return sourceApplication; } + public ApplicationPackage applicationPackage() { return applicationPackage; } /** * Returns a field defined in this search definition or one if its documents. Fields in this search definition takes diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/AttributeFields.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/AttributeFields.java index 55f3a94bb70..a3580a404a3 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/AttributeFields.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/AttributeFields.java @@ -41,8 +41,11 @@ public class AttributeFields extends Derived implements AttributesConfig.Produce /** Whether this has any position attribute */ private boolean hasPosition = false; + public static final AttributeFields empty = new AttributeFields(null); + public AttributeFields(Search search) { - derive(search); + if (search != null) + derive(search); } /** Derives everything from a field */ diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java index 1e978e43d6a..10881ab9ce0 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java @@ -12,11 +12,16 @@ import java.util.Map; /** * The derived rank profiles of a search definition * - * @author bratseth + * @author bratseth */ public class RankProfileList extends Derived implements RankProfilesConfig.Producer { - private Map<String, RawRankProfile> rankProfiles = new java.util.LinkedHashMap<>(); + private final Map<String, RawRankProfile> rankProfiles = new java.util.LinkedHashMap<>(); + + public static RankProfileList empty = new RankProfileList(); + + private RankProfileList() { + } /** * Creates a rank profile @@ -29,7 +34,7 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ RankProfileRegistry rankProfileRegistry, QueryProfileRegistry queryProfiles, ImportedModels importedModels) { - setName(search.getName()); + setName(search == null ? "default" : search.getName()); deriveRankProfiles(rankProfileRegistry, queryProfiles, importedModels, search, attributeFields); } @@ -38,14 +43,16 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ ImportedModels importedModels, Search search, AttributeFields attributeFields) { - RawRankProfile defaultProfile = new RawRankProfile(rankProfileRegistry.getRankProfile(search, "default"), - queryProfiles, - importedModels, - attributeFields); - rankProfiles.put(defaultProfile.getName(), defaultProfile); + if (search != null) { // profiles belonging to a search have a default profile + RawRankProfile defaultProfile = new RawRankProfile(rankProfileRegistry.get(search, "default"), + queryProfiles, + importedModels, + attributeFields); + rankProfiles.put(defaultProfile.getName(), defaultProfile); + } - for (RankProfile rank : rankProfileRegistry.localRankProfiles(search)) { - if ("default".equals(rank.getName())) continue; + for (RankProfile rank : rankProfileRegistry.rankProfilesOf(search)) { + if (search != null && "default".equals(rank.getName())) continue; RawRankProfile rawRank = new RawRankProfile(rank, queryProfiles, importedModels, attributeFields); rankProfiles.put(rawRank.getName(), rawRank); @@ -70,4 +77,5 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ rank.getConfig(builder); } } + } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java index a38fbe1aaa0..629fa9624c5 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java @@ -81,7 +81,7 @@ public class ConvertedModel { public ConvertedModel(Path modelPath, RankProfileTransformContext context) { this.modelPath = modelPath; this.modelName = toModelName(modelPath); - ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), modelPath); + ModelStore store = new ModelStore(context.rankProfile().applicationPackage(), modelPath); if ( store.hasSourceModel()) expressions = convertModel(store, context.rankProfile(), context.queryProfiles(), context.importedModels()); else @@ -145,39 +145,11 @@ public class ConvertedModel { // Add expressions Map<String, RankingExpression> expressions = new HashMap<>(); - for (Map.Entry<String, ImportedModel.Signature> signatureEntry : model.signatures().entrySet()) { - for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet()) { - addExpression(model.expressions().get(outputEntry.getValue()), - modelName + "." + signatureEntry.getKey() + "." + outputEntry.getKey(), - constantsReplacedByMacros, - model, store, profile, queryProfiles, - expressions); - } - if (signatureEntry.getValue().outputs().isEmpty()) { // fallback: Signature without outputs - addExpression(model.expressions().get(signatureEntry.getKey()), - modelName + "." + signatureEntry.getKey(), - constantsReplacedByMacros, - model, store, profile, queryProfiles, - expressions); - } - } - if (model.signatures().isEmpty()) { // fallback: Model without signatures - if (model.expressions().size() == 1) { // Use just model name - addExpression(model.expressions().values().iterator().next(), - modelName, - constantsReplacedByMacros, - model, store, profile, queryProfiles, - expressions); - } - else { - for (Map.Entry<String, RankingExpression> expressionEntry : model.expressions().entrySet()) { - addExpression(expressionEntry.getValue(), - modelName + "." + expressionEntry.getKey(), - constantsReplacedByMacros, - model, store, profile, queryProfiles, - expressions); - } - } + for (Pair<String, RankingExpression> output : model.outputExpressions()) { + addExpression(output.getSecond(), output.getFirst(), + constantsReplacedByMacros, + model, store, profile, queryProfiles, + expressions); } // Transform and save macro - must come after reading expressions due to optimization transforms @@ -209,8 +181,8 @@ public class ConvertedModel { profile.addConstant(constant.getFirst(), asValue(constant.getSecond())); for (RankingConstant constant : store.readLargeConstants()) { - if ( ! profile.getSearch().getRankingConstants().containsKey(constant.getName())) - profile.getSearch().addRankingConstant(constant); + if ( ! profile.rankingConstants().asMap().containsKey(constant.getName())) + profile.rankingConstants().add(constant); } for (Pair<String, RankingExpression> macro : store.readMacros()) { @@ -238,9 +210,9 @@ public class ConvertedModel { } else { Path constantPath = store.writeLargeConstant(constantName, constantValue); - if ( ! profile.getSearch().getRankingConstants().containsKey(constantName)) { - profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), - constantPath.toString())); + if ( ! profile.rankingConstants().asMap().containsKey(constantName)) { + profile.rankingConstants().add(new RankingConstant(constantName, constantValue.type(), + constantPath.toString())); } } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/DiversitySettingsValidator.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/DiversitySettingsValidator.java index a936045af3a..6b78da2146b 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/DiversitySettingsValidator.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/DiversitySettingsValidator.java @@ -21,7 +21,7 @@ public class DiversitySettingsValidator extends Processor { public void process(boolean validate) { if ( ! validate) return; - for (RankProfile rankProfile : rankProfileRegistry.localRankProfiles(search)) { + for (RankProfile rankProfile : rankProfileRegistry.rankProfilesOf(search)) { if (rankProfile.getMatchPhaseSettings() != null && rankProfile.getMatchPhaseSettings().getDiversity() != null) { validate(rankProfile, rankProfile.getMatchPhaseSettings().getDiversity()); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/FilterFieldNames.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/FilterFieldNames.java index 39d35cce694..0c75314ffa2 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/FilterFieldNames.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/FilterFieldNames.java @@ -33,7 +33,7 @@ public class FilterFieldNames extends Processor { } } - for (RankProfile profile : rankProfileRegistry.localRankProfiles(search)) { + for (RankProfile profile : rankProfileRegistry.rankProfilesOf(search)) { Set<String> filterFields = new LinkedHashSet<>(); findFilterFields(search, profile, filterFields); for (Iterator<String> itr = filterFields.iterator(); itr.hasNext(); ) { @@ -45,7 +45,7 @@ public class FilterFieldNames extends Processor { } private void filterField(String f) { - for (RankProfile rp : rankProfileRegistry.localRankProfiles(search)) { + for (RankProfile rp : rankProfileRegistry.rankProfilesOf(search)) { rp.filterFields().add(f); } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/MatchPhaseSettingsValidator.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/MatchPhaseSettingsValidator.java index 043eb1f82eb..479384e09ef 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/MatchPhaseSettingsValidator.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/MatchPhaseSettingsValidator.java @@ -23,7 +23,7 @@ public class MatchPhaseSettingsValidator extends Processor { public void process(boolean validate) { if ( ! validate) return; - for (RankProfile rankProfile : rankProfileRegistry.localRankProfiles(search)) { + for (RankProfile rankProfile : rankProfileRegistry.rankProfilesOf(search)) { RankProfile.MatchPhaseSettings settings = rankProfile.getMatchPhaseSettings(); if (settings != null) { validateMatchPhaseSettings(rankProfile, settings); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processor.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processor.java index b0fbc6c1998..b938e40d9a2 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processor.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processor.java @@ -100,7 +100,7 @@ public abstract class Processor { { List<RankProfile.RankSetting> someRankSettings = new java.util.ArrayList<>(); - for (RankProfile profile : rankProfileRegistry.localRankProfiles(search)) { + for (RankProfile profile : rankProfileRegistry.rankProfilesOf(search)) { for (Iterator j = profile.declaredRankSettingIterator(); j.hasNext(); ) { RankProfile.RankSetting setting = (RankProfile.RankSetting)j.next(); if (setting.getType().equals(type)) { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidator.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidator.java index f7f314f8444..81455991cc9 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidator.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidator.java @@ -36,7 +36,7 @@ public class RankingExpressionTypeValidator extends Processor { public void process(boolean validate) { if ( ! validate) return; - for (RankProfile profile : rankProfileRegistry.localRankProfiles(search)) { + for (RankProfile profile : rankProfileRegistry.rankProfilesOf(search)) { try { validate(profile); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/ReservedMacroNames.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/ReservedMacroNames.java index b8eb1e1d8cf..adcebed9254 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/ReservedMacroNames.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/ReservedMacroNames.java @@ -9,7 +9,6 @@ import com.yahoo.searchdefinition.Search; import com.yahoo.searchlib.rankingexpression.parser.RankingExpressionParserConstants; import com.yahoo.vespa.model.container.search.QueryProfiles; -import java.util.HashSet; import java.util.Set; import java.util.logging.Level; @@ -31,7 +30,7 @@ public class ReservedMacroNames extends Processor { public void process(boolean validate) { if ( ! validate) return; - for (RankProfile rp : rankProfileRegistry.allRankProfiles()) { + for (RankProfile rp : rankProfileRegistry.all()) { for (String macroName : rp.getMacros().keySet()) { if (reservedNames.contains(macroName)) { deployLogger.log(Level.WARNING, "Macro \"" + macroName + "\" " + diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/multifieldresolver/RankProfileTypeSettingsProcessor.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/multifieldresolver/RankProfileTypeSettingsProcessor.java index 21567edb94b..cc1638347f6 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/multifieldresolver/RankProfileTypeSettingsProcessor.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/multifieldresolver/RankProfileTypeSettingsProcessor.java @@ -66,7 +66,7 @@ public class RankProfileTypeSettingsProcessor extends Processor { } private void addAttributeTypeToRankProfiles(String attributeName, String attributeType) { - for (RankProfile profile : rankProfileRegistry.allRankProfiles()) { + for (RankProfile profile : rankProfileRegistry.all()) { profile.addAttributeType(attributeName, attributeType); } } @@ -90,7 +90,7 @@ public class RankProfileTypeSettingsProcessor extends Processor { } private void addQueryFeatureTypeToRankProfiles(String queryFeature, String queryFeatureType) { - for (RankProfile profile : rankProfileRegistry.allRankProfiles()) { + for (RankProfile profile : rankProfileRegistry.all()) { profile.addQueryFeatureType(queryFeature, queryFeatureType); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java index f3e7a9623d1..73dd60f63eb 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java @@ -1,6 +1,8 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model; +import com.google.common.collect.ImmutableList; +import com.yahoo.collections.Pair; import com.yahoo.config.ConfigBuilder; import com.yahoo.config.ConfigInstance; import com.yahoo.config.ConfigInstance.Builder; @@ -18,13 +20,20 @@ import com.yahoo.config.model.NullConfigModelRegistry; import com.yahoo.config.model.api.FileDistribution; import com.yahoo.config.model.api.HostInfo; import com.yahoo.config.model.api.Model; -import com.yahoo.config.model.api.ValidationParameters; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.config.model.producer.AbstractConfigProducer; import com.yahoo.config.model.producer.AbstractConfigProducerRoot; import com.yahoo.config.model.producer.UserConfigRepo; import com.yahoo.config.provision.AllocatedHosts; import com.yahoo.log.LogLevel; +import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchdefinition.RankProfileRegistry; +import com.yahoo.searchdefinition.RankingConstants; +import com.yahoo.searchdefinition.derived.AttributeFields; +import com.yahoo.searchdefinition.derived.RankProfileList; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import com.yahoo.vespa.config.ConfigDefinitionKey; import com.yahoo.vespa.config.ConfigKey; import com.yahoo.vespa.config.ConfigPayload; @@ -92,13 +101,19 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri */ public static final String ROOT_CONFIGID = ""; - private ApplicationConfigProducerRoot root = null; + private ApplicationConfigProducerRoot root; - /** - * Generic service instances - service clusters which have no specific model - */ + private final ApplicationPackage applicationPackage; + + /** Generic service instances - service clusters which have no specific model */ private List<ServiceCluster> serviceClusters = new ArrayList<>(); + /** The global rank profiles of this model */ + private final RankProfileList rankProfileList; + + /** The global ranking constants of this model */ + private final RankingConstants rankingConstants = new RankingConstants(); + private DeployState deployState; /** The validation overrides of this. This is never null. */ @@ -144,11 +159,21 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri this.validationOverrides = deployState.validationOverrides(); configModelRegistry = new VespaConfigModelRegistry(configModelRegistry); VespaModelBuilder builder = new VespaDomBuilder(); + this.applicationPackage = deployState.getApplicationPackage(); root = builder.getRoot(VespaModel.ROOT_CONFIGID, deployState, this); + + createGlobalRankProfiles(deployState.getImportedModels(), deployState.rankProfileRegistry()); + this.rankProfileList = new RankProfileList(null, // null search -> global + AttributeFields.empty, + deployState.rankProfileRegistry(), + deployState.getQueryProfiles().getRegistry(), + deployState.getImportedModels()); + if (complete) { // create a a completed, frozen model - configModelRepo.readConfigModels(deployState, builder, root, configModelRegistry); + configModelRepo.readConfigModels(deployState, this, builder, root, configModelRegistry); addServiceClusters(deployState.getApplicationPackage(), builder); this.allocatedHosts = AllocatedHosts.withHosts(root.getHostSystem().getHostSpecs()); // must happen after the two lines above + setupRouting(); this.fileDistributor = root.getFileDistributionConfigProducer().getFileDistributor(); getAdmin().addPerHostServices(getHostSystem().getHosts(), deployState); @@ -164,6 +189,12 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri } } + /** Returns the application package owning this */ + public ApplicationPackage applicationPackage() { return applicationPackage; } + + /** Returns the global ranking constants of this */ + public RankingConstants rankingConstants() { return rankingConstants; } + /** Creates a mutable model with no services instantiated */ public static VespaModel createIncomplete(DeployState deployState) throws IOException, SAXException { return new VespaModel(new NullConfigModelRegistry(), deployState, false, new FileDistributor(deployState.getFileRegistry(), null)); @@ -185,8 +216,27 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri serviceClusters.add(sc); } + /** + * Creates a rank profile not attached to any search definition, for each imported model in the application package + */ + private ImmutableList<RankProfile> createGlobalRankProfiles(ImportedModels importedModels, + RankProfileRegistry rankProfileRegistry) { + List<RankProfile> profiles = new ArrayList<>(); + for (ImportedModel model : importedModels.all()) { + RankProfile profile = new RankProfile(model.name(), this, rankProfileRegistry); + for (Pair<String, RankingExpression> entry : model.outputExpressions()) { + profile.addMacro(entry.getFirst(), false).setRankingExpression(entry.getSecond()); + } + rankProfileRegistry.add(profile); + } + return ImmutableList.copyOf(profiles); + } + + /** Returns the global rank profiles as a rank profile list */ + public RankProfileList rankProfileList() { return rankProfileList; } + private void setupRouting() { - root.setupRouting(configModelRepo); + root.setupRouting(this, configModelRepo); } /** Returns the one and only HostSystem of this VespaModel */ diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java index d022b2cf8ab..907418ea9f0 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java @@ -48,7 +48,7 @@ public class RankingConstantsValidator extends Validator { ExceptionMessageCollector exceptionMessageCollector = new ExceptionMessageCollector("Invalid constant tensor file(s):"); for (SearchDefinition sd : deployState.getSearchDefinitions()) { - for (RankingConstant rc : sd.getSearch().getRankingConstants().values()) { + for (RankingConstant rc : sd.getSearch().rankingConstants().asMap().values()) { try { validateRankingConstant(rc, applicationPackage); } catch (InvalidConstantTensor | FileNotFoundException ex) { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomAdminV2Builder.java b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomAdminV2Builder.java index 75cd755a91d..d67cb0c29c3 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomAdminV2Builder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomAdminV2Builder.java @@ -5,6 +5,7 @@ import com.yahoo.config.application.api.FileRegistry; import com.yahoo.config.model.ConfigModelContext; import com.yahoo.config.model.api.ConfigServerSpec; import com.yahoo.config.model.producer.AbstractConfigProducer; +import com.yahoo.searchdefinition.derived.RankProfileList; import com.yahoo.text.XML; import com.yahoo.log.LogLevel; import com.yahoo.vespa.model.SimpleConfigProducer; @@ -89,7 +90,7 @@ public class DomAdminV2Builder extends DomAdminBuilderBase { if (standaloneZooKeeper) { parent = new ClusterControllerCluster(parent, "standalone"); } - ContainerCluster cluster = new ContainerCluster(parent, "cluster-controllers", "cluster-controllers", new ClusterControllerClusterVerifier()); + ContainerCluster cluster = new ContainerCluster(parent, "cluster-controllers", "cluster-controllers", new ClusterControllerClusterVerifier(), RankProfileList.empty); ContainerModelBuilder.addDefaultHandler_legacyBuilder(cluster); List<Container> containers = new ArrayList<>(); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java index 469b6781bae..4c5dafb7d8f 100755 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java @@ -40,6 +40,10 @@ import com.yahoo.search.config.IndexInfoConfig; import com.yahoo.search.config.QrStartConfig; import com.yahoo.search.pagetemplates.PageTemplatesConfig; import com.yahoo.search.query.profile.config.QueryProfilesConfig; +import com.yahoo.searchdefinition.RankProfileRegistry; +import com.yahoo.searchdefinition.derived.AttributeFields; +import com.yahoo.searchdefinition.derived.RankProfileList; +import com.yahoo.vespa.config.search.RankProfilesConfig; import com.yahoo.vespa.configdefinition.IlscriptsConfig; import com.yahoo.vespa.model.PortsMeta; import com.yahoo.vespa.model.Service; @@ -65,6 +69,7 @@ import com.yahoo.vespa.model.container.jersey.JerseyHandler; import com.yahoo.vespa.model.container.jersey.RestApi; import com.yahoo.vespa.model.container.processing.ProcessingChains; import com.yahoo.vespa.model.container.search.ContainerSearch; +import com.yahoo.vespa.model.container.search.QueryProfiles; import com.yahoo.vespa.model.container.search.searchchain.SearchChains; import com.yahoo.vespa.model.content.Content; import com.yahoo.vespa.model.search.AbstractSearchCluster; @@ -122,7 +127,9 @@ public final class ContainerCluster ServletPathsConfig.Producer, RoutingProviderConfig.Producer, ConfigserverConfig.Producer, - ThreadpoolConfig.Producer + ThreadpoolConfig.Producer, + RankProfilesConfig.Producer + { /** @@ -169,6 +176,9 @@ public final class ContainerCluster private final ContainerClusterVerifier clusterVerifier; private final boolean isHostedVespa; + /** Global rank profiles, aka models */ + private final RankProfileList rankProfileList; + private Map<String, String> concreteDocumentTypes = new LinkedHashMap<>(); private MetricDefaultsConfig.Factory.Enum defaultMetricConsumerFactory; @@ -193,11 +203,30 @@ public final class ContainerCluster } } - public ContainerCluster(AbstractConfigProducer<?> parent, String subId, String name) { - this(parent, subId, name, new AcceptAllVerifier()); + /** + * Creates a container cluster + * + * @param rankProfileList the list ofd global rank profiles containing models that should be available in + * container clusters + */ + public ContainerCluster(AbstractConfigProducer<?> parent, + String subId, + String name, + RankProfileList rankProfileList) { + this(parent, subId, name, new AcceptAllVerifier(), rankProfileList); } - public ContainerCluster(AbstractConfigProducer<?> parent, String subId, String name, ContainerClusterVerifier verifier) { + /** + * Creates a container cluster + * + * @param rankProfileList the list ofd global rank profiles containing models that should be available in + * container clusters + */ + public ContainerCluster(AbstractConfigProducer<?> parent, + String subId, + String name, + ContainerClusterVerifier verifier, + RankProfileList rankProfileList) { super(parent, subId); this.clusterVerifier = verifier; this.name = name; @@ -207,6 +236,7 @@ public final class ContainerCluster componentGroup = new ComponentGroup<>(this, "component"); restApiGroup = new ConfigProducerGroup<>(this, "rest-api"); servletGroup = new ConfigProducerGroup<>(this, "servlet"); + this.rankProfileList = Objects.requireNonNull(rankProfileList, "rankProfileList cannot be null"); addComponent(new StatisticsComponent()); addSimpleComponent(AccessLog.class); @@ -694,6 +724,11 @@ public final class ContainerCluster containerDocproc.getConfig(builder); } + @Override + public void getConfig(RankProfilesConfig.Builder builder) { + rankProfileList.getConfig(builder); + } + public void setMbusParams(MbusParams mbusParams) { this.mbusParams = mbusParams; } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java index d81026c54d1..cb4cf92a223 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java @@ -25,6 +25,7 @@ import com.yahoo.config.provision.Zone; import com.yahoo.container.jdisc.config.MetricDefaultsConfig; import com.yahoo.osgi.provider.model.ComponentModel; import com.yahoo.search.rendering.RendererRegistry; +import com.yahoo.searchdefinition.derived.RankProfileList; import com.yahoo.text.XML; import com.yahoo.vespa.defaults.Defaults; import com.yahoo.vespa.model.AbstractService; @@ -143,11 +144,15 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { } } - private ContainerCluster createContainerCluster(Element spec, final ConfigModelContext modelContext) { + private ContainerCluster createContainerCluster(Element spec, ConfigModelContext modelContext) { return new VespaDomBuilder.DomConfigProducerBuilder<ContainerCluster>() { @Override protected ContainerCluster doBuild(AbstractConfigProducer ancestor, Element producerSpec) { - return new ContainerCluster(ancestor, modelContext.getProducerId(), modelContext.getProducerId()); + return new ContainerCluster(ancestor, + modelContext.getProducerId(), + modelContext.getProducerId(), + modelContext.vespaModel() != null ? modelContext.vespaModel().rankProfileList() + : RankProfileList.empty); } }.build(modelContext.getParentProducer(), spec); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/content/Content.java b/config-model/src/main/java/com/yahoo/vespa/model/content/Content.java index 7bd70bba87a..d3709e88f29 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/content/Content.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/content/Content.java @@ -14,6 +14,7 @@ import com.yahoo.config.model.builder.xml.ConfigModelBuilder; import com.yahoo.config.model.builder.xml.ConfigModelId; import com.yahoo.config.model.producer.AbstractConfigProducer; import com.yahoo.log.LogLevel; +import com.yahoo.searchdefinition.derived.RankProfileList; import com.yahoo.vespa.model.*; import com.yahoo.vespa.model.admin.Admin; import com.yahoo.vespa.model.container.Container; @@ -298,7 +299,7 @@ public class Content extends ConfigModel { AbstractConfigProducer parent = root.getChildren().get(ContainerModel.DOCPROC_RESERVED_NAME); if (parent == null) parent = new SimpleConfigProducer(root, ContainerModel.DOCPROC_RESERVED_NAME); - ContainerCluster indexingCluster = new ContainerCluster(parent, "cluster." + indexerName, indexerName); + ContainerCluster indexingCluster = new ContainerCluster(parent, "cluster." + indexerName, indexerName, RankProfileList.empty); ContainerModel indexingClusterModel = new ContainerModel(modelContext.withParent(parent).withId(indexingCluster.getSubId())); indexingClusterModel.setCluster(indexingCluster); modelContext.getConfigModelRepoAdder().add(indexingClusterModel); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/ContentCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/ContentCluster.java index 154f719ff10..cce367ed611 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/ContentCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/ContentCluster.java @@ -8,6 +8,7 @@ import com.yahoo.config.model.producer.AbstractConfigProducerRoot; import com.yahoo.config.provision.ClusterSpec; import com.yahoo.config.provision.Environment; import com.yahoo.config.provision.Zone; +import com.yahoo.searchdefinition.derived.RankProfileList; import com.yahoo.vespa.config.content.MessagetyperouteselectorpolicyConfig; import com.yahoo.vespa.config.content.FleetcontrollerConfig; import com.yahoo.vespa.config.content.StorDistributionConfig; @@ -435,8 +436,15 @@ public class ContentCluster extends AbstractConfigProducer implements return sortedHosts; } - private ContainerCluster createClusterControllers(AbstractConfigProducer parent, Collection<HostResource> hosts, String name, boolean multitenant) { - ContainerCluster clusterControllers = new ContainerCluster(parent, name, name, new ClusterControllerClusterVerifier()); + private ContainerCluster createClusterControllers(AbstractConfigProducer parent, + Collection<HostResource> hosts, + String name, + boolean multitenant) { + ContainerCluster clusterControllers = new ContainerCluster(parent, + name, + name, + new ClusterControllerClusterVerifier(), + RankProfileList.empty); List<Container> containers = new ArrayList<>(); // Add a cluster controller on each config server (there is always at least one). if (clusterControllers.getContainers().isEmpty()) { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java index 9550cd82b22..83da5d96418 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java @@ -37,7 +37,7 @@ public abstract class AbstractSearchCluster extends AbstractConfigProducer public void prepareToDistributeFiles(List<SearchNode> backends) { for (SearchDefinitionSpec sds : localSDS) { - for (RankingConstant constant : sds.getSearchDefinition().getSearch().getRankingConstants().values()) { + for (RankingConstant constant : sds.getSearchDefinition().getSearch().rankingConstants().asMap().values()) { FileReference reference = (constant.getPathType() == RankingConstant.PathType.FILE) ? FileSender.sendFileToServices(constant.getFileName(), backends) : FileSender.sendUriToServices(constant.getUri(), backends); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/DocumentDatabase.java b/config-model/src/main/java/com/yahoo/vespa/model/search/DocumentDatabase.java index 1413d515103..a6bf51a2503 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/search/DocumentDatabase.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/search/DocumentDatabase.java @@ -75,7 +75,7 @@ public class DocumentDatabase extends AbstractConfigProducer implements @Override public void getConfig(RankingConstantsConfig.Builder builder) { - for (RankingConstant constant : derivedCfg.getSearch().getRankingConstants().values()) { + for (RankingConstant constant : derivedCfg.getSearch().rankingConstants().asMap().values()) { if ("".equals(constant.getFileReference())) { System.err.println("INVALID rank constant "+constant.getName()+" [missing file reference]"); // TODO: Throw or log warning continue; diff --git a/config-model/src/main/javacc/SDParser.jj b/config-model/src/main/javacc/SDParser.jj index fc4cbd0a495..12e61dea450 100644 --- a/config-model/src/main/javacc/SDParser.jj +++ b/config-model/src/main/javacc/SDParser.jj @@ -407,8 +407,8 @@ Search rootSearch(String dir) : deployLogger.log(Level.WARNING, name + " can not be used in YQL+ expressions."); } search = new Search(name, app); - rankProfileRegistry.addRankProfile(new DefaultRankProfile(search, rankProfileRegistry)); - rankProfileRegistry.addRankProfile(new UnrankedRankProfile(search, rankProfileRegistry));} + rankProfileRegistry.add(new DefaultRankProfile(search, rankProfileRegistry)); + rankProfileRegistry.add(new UnrankedRankProfile(search, rankProfileRegistry));} lbrace() (rootSearchItem(search) (<NL>)*)* <RBRACE> (<NL>)* <EOF>) { return search; } } @@ -1801,7 +1801,7 @@ void rankingConstant(Search search) : } lbrace() (rankingConstantItem(constant) (<NL>)*)+ <RBRACE> ) { - search.addRankingConstant(constant); + search.rankingConstants().add(constant); } } @@ -1857,7 +1857,7 @@ void rankProfile(Search search) : ( <RANKPROFILE> name = identifier() { if ("default".equals(name)) { - profile = rankProfileRegistry.getRankProfile(search, "default"); + profile = rankProfileRegistry.get(search, "default"); } else { profile = new RankProfile(name, search, rankProfileRegistry); } @@ -1865,7 +1865,7 @@ void rankProfile(Search search) : [inheritsRankProfile(profile)] lbrace() (rankProfileItem(profile) (<NL>)*)* <RBRACE> ) { - rankProfileRegistry.addRankProfile(profile); + rankProfileRegistry.add(profile); } } diff --git a/config-model/src/test/cfg/application/ml_serving/models/mnist_softmax.onnx b/config-model/src/test/cfg/application/ml_serving/models/mnist_softmax.onnx Binary files differnew file mode 100644 index 00000000000..a86019bf53a --- /dev/null +++ b/config-model/src/test/cfg/application/ml_serving/models/mnist_softmax.onnx diff --git a/config-model/src/test/cfg/application/ml_serving/models/mnist_softmax/mnist_sftmax_with_saving.py b/config-model/src/test/cfg/application/ml_serving/models/mnist_softmax/mnist_sftmax_with_saving.py new file mode 100644 index 00000000000..5d67a267706 --- /dev/null +++ b/config-model/src/test/cfg/application/ml_serving/models/mnist_softmax/mnist_sftmax_with_saving.py @@ -0,0 +1,92 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""A very simple MNIST classifier. + +See extensive documentation at +https://www.tensorflow.org/get_started/mnist/beginners +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys + +from tensorflow.examples.tutorials.mnist import input_data + +import tensorflow as tf + +FLAGS = None + + +def main(_): + # Import data + mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True) + + # Create the model + x = tf.placeholder(tf.float32, [None, 784]) + + with tf.name_scope("layer"): + W = tf.Variable(tf.zeros([784, 10])) + b = tf.Variable(tf.zeros([10])) + y = tf.matmul(x, W) + b + + + # Define loss and optimizer + y_ = tf.placeholder(tf.float32, [None, 10]) + + # The raw formulation of cross-entropy, + # + # tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)), + # reduction_indices=[1])) + # + # can be numerically unstable. + # + # So here we use tf.nn.softmax_cross_entropy_with_logits on the raw + # outputs of 'y', and then average across the batch. + cross_entropy = tf.reduce_mean( + tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y)) + train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) + + sess = tf.InteractiveSession() + tf.global_variables_initializer().run() + # Train + for _ in range(1000): + batch_xs, batch_ys = mnist.train.next_batch(100) + sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) + + # Test trained model + correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) + accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) + print(sess.run(accuracy, feed_dict={x: mnist.test.images, + y_: mnist.test.labels})) + + # Save the model + export_path = "saved" + print('Exporting trained model to ', export_path) + builder = tf.saved_model.builder.SavedModelBuilder(export_path) + signature = tf.saved_model.signature_def_utils.predict_signature_def(inputs = {'x':x}, outputs = {'y':y}) + builder.add_meta_graph_and_variables(sess, + [tf.saved_model.tag_constants.SERVING], + signature_def_map={'serving_default':signature}) + builder.save(as_text=True) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data', + help='Directory for storing input data') + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/config-model/src/test/cfg/application/ml_serving/models/mnist_softmax/saved/saved_model.pbtxt b/config-model/src/test/cfg/application/ml_serving/models/mnist_softmax/saved/saved_model.pbtxt new file mode 100644 index 00000000000..05b0e4e0f29 --- /dev/null +++ b/config-model/src/test/cfg/application/ml_serving/models/mnist_softmax/saved/saved_model.pbtxt @@ -0,0 +1,5039 @@ +saved_model_schema_version: 1 +meta_graphs { + meta_info_def { + stripped_op_list { + op { + name: "Add" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_STRING + } + } + } + } + op { + name: "ApplyGradientDescent" + input_arg { + name: "var" + type_attr: "T" + is_ref: true + } + input_arg { + name: "alpha" + type_attr: "T" + } + input_arg { + name: "delta" + type_attr: "T" + } + output_arg { + name: "out" + type_attr: "T" + is_ref: true + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "use_locking" + type: "bool" + default_value { + b: false + } + } + } + op { + name: "ArgMax" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "dimension" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "output_type" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "output_type" + type: "type" + default_value { + type: DT_INT64 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "Assign" + input_arg { + name: "ref" + type_attr: "T" + is_ref: true + } + input_arg { + name: "value" + type_attr: "T" + } + output_arg { + name: "output_ref" + type_attr: "T" + is_ref: true + } + attr { + name: "T" + type: "type" + } + attr { + name: "validate_shape" + type: "bool" + default_value { + b: true + } + } + attr { + name: "use_locking" + type: "bool" + default_value { + b: true + } + } + allows_uninitialized_input: true + } + op { + name: "BroadcastGradientArgs" + input_arg { + name: "s0" + type_attr: "T" + } + input_arg { + name: "s1" + type_attr: "T" + } + output_arg { + name: "r0" + type_attr: "T" + } + output_arg { + name: "r1" + type_attr: "T" + } + attr { + name: "T" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "Cast" + input_arg { + name: "x" + type_attr: "SrcT" + } + output_arg { + name: "y" + type_attr: "DstT" + } + attr { + name: "SrcT" + type: "type" + } + attr { + name: "DstT" + type: "type" + } + } + op { + name: "ConcatV2" + input_arg { + name: "values" + type_attr: "T" + number_attr: "N" + } + input_arg { + name: "axis" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 2 + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "Const" + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "value" + type: "tensor" + } + attr { + name: "dtype" + type: "type" + } + } + op { + name: "Equal" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type: DT_BOOL + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_QUINT8 + type: DT_QINT8 + type: DT_QINT32 + type: DT_STRING + type: DT_BOOL + type: DT_COMPLEX128 + } + } + } + is_commutative: true + } + op { + name: "ExpandDims" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "dim" + type_attr: "Tdim" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tdim" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "Fill" + input_arg { + name: "dims" + type: DT_INT32 + } + input_arg { + name: "value" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + } + op { + name: "FloorDiv" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Identity" + input_arg { + name: "input" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + } + op { + name: "MatMul" + input_arg { + name: "a" + type_attr: "T" + } + input_arg { + name: "b" + type_attr: "T" + } + output_arg { + name: "product" + type_attr: "T" + } + attr { + name: "transpose_a" + type: "bool" + default_value { + b: false + } + } + attr { + name: "transpose_b" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Maximum" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + } + } + } + is_commutative: true + } + op { + name: "Mean" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "reduction_indices" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "keep_dims" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "MergeV2Checkpoints" + input_arg { + name: "checkpoint_prefixes" + type: DT_STRING + } + input_arg { + name: "destination_prefix" + type: DT_STRING + } + attr { + name: "delete_old_dirs" + type: "bool" + default_value { + b: true + } + } + is_stateful: true + } + op { + name: "Mul" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + is_commutative: true + } + op { + name: "NoOp" + } + op { + name: "Pack" + input_arg { + name: "values" + type_attr: "T" + number_attr: "N" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "T" + type: "type" + } + attr { + name: "axis" + type: "int" + default_value { + i: 0 + } + } + } + op { + name: "Placeholder" + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + } + attr { + name: "shape" + type: "shape" + default_value { + shape { + unknown_rank: true + } + } + } + } + op { + name: "Prod" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "reduction_indices" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "keep_dims" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "RealDiv" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Reshape" + input_arg { + name: "tensor" + type_attr: "T" + } + input_arg { + name: "shape" + type_attr: "Tshape" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tshape" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "RestoreV2" + input_arg { + name: "prefix" + type: DT_STRING + } + input_arg { + name: "tensor_names" + type: DT_STRING + } + input_arg { + name: "shape_and_slices" + type: DT_STRING + } + output_arg { + name: "tensors" + type_list_attr: "dtypes" + } + attr { + name: "dtypes" + type: "list(type)" + has_minimum: true + minimum: 1 + } + is_stateful: true + } + op { + name: "SaveV2" + input_arg { + name: "prefix" + type: DT_STRING + } + input_arg { + name: "tensor_names" + type: DT_STRING + } + input_arg { + name: "shape_and_slices" + type: DT_STRING + } + input_arg { + name: "tensors" + type_list_attr: "dtypes" + } + attr { + name: "dtypes" + type: "list(type)" + has_minimum: true + minimum: 1 + } + is_stateful: true + } + op { + name: "Shape" + input_arg { + name: "input" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "out_type" + } + attr { + name: "T" + type: "type" + } + attr { + name: "out_type" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "ShardedFilename" + input_arg { + name: "basename" + type: DT_STRING + } + input_arg { + name: "shard" + type: DT_INT32 + } + input_arg { + name: "num_shards" + type: DT_INT32 + } + output_arg { + name: "filename" + type: DT_STRING + } + } + op { + name: "Slice" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "begin" + type_attr: "Index" + } + input_arg { + name: "size" + type_attr: "Index" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Index" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "SoftmaxCrossEntropyWithLogits" + input_arg { + name: "features" + type_attr: "T" + } + input_arg { + name: "labels" + type_attr: "T" + } + output_arg { + name: "loss" + type_attr: "T" + } + output_arg { + name: "backprop" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + } + op { + name: "StringJoin" + input_arg { + name: "inputs" + type: DT_STRING + number_attr: "N" + } + output_arg { + name: "output" + type: DT_STRING + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "separator" + type: "string" + default_value { + s: "" + } + } + } + op { + name: "Sub" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Sum" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "reduction_indices" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "keep_dims" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + } + } + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "Tile" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "multiples" + type_attr: "Tmultiples" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tmultiples" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "VariableV2" + output_arg { + name: "ref" + type_attr: "dtype" + is_ref: true + } + attr { + name: "shape" + type: "shape" + } + attr { + name: "dtype" + type: "type" + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + is_stateful: true + } + op { + name: "ZerosLike" + input_arg { + name: "x" + type_attr: "T" + } + output_arg { + name: "y" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + } + } + tags: "serve" + tensorflow_version: "1.4.1" + tensorflow_git_version: "v1.4.0-19-ga52c8d9" + } + graph_def { + node { + name: "Placeholder" + op: "Placeholder" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + node { + name: "Placeholder_1" + op: "Placeholder" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + node { + name: "layer/zeros" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + float_val: 0.0 + } + } + } + } + node { + name: "layer/Variable" + op: "VariableV2" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } + } + node { + name: "layer/Variable/Assign" + op: "Assign" + input: "layer/Variable" + input: "layer/zeros" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@layer/Variable" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "layer/Variable/read" + op: "Identity" + input: "layer/Variable" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@layer/Variable" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "layer/zeros_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 10 + } + } + float_val: 0.0 + } + } + } + } + node { + name: "layer/Variable_1" + op: "VariableV2" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 10 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } + } + node { + name: "layer/Variable_1/Assign" + op: "Assign" + input: "layer/Variable_1" + input: "layer/zeros_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@layer/Variable_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "layer/Variable_1/read" + op: "Identity" + input: "layer/Variable_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@layer/Variable_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + } + node { + name: "layer/MatMul" + op: "MatMul" + input: "Placeholder" + input: "layer/Variable/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + } + } + node { + name: "layer/add" + op: "Add" + input: "layer/MatMul" + input: "layer/Variable_1/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "Rank" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } + } + node { + name: "Shape" + op: "Shape" + input: "layer/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "Rank_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } + } + node { + name: "Shape_1" + op: "Shape" + input: "layer/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "Sub/y" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "Sub" + op: "Sub" + input: "Rank_1" + input: "Sub/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "Slice/begin" + op: "Pack" + input: "Sub" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node { + name: "Slice/size" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node { + name: "Slice" + op: "Slice" + input: "Shape_1" + input: "Slice/begin" + input: "Slice/size" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + } + node { + name: "concat/values_0" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: -1 + } + } + } + } + node { + name: "concat/axis" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + } + node { + name: "concat" + op: "ConcatV2" + input: "concat/values_0" + input: "Slice" + input: "concat/axis" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + } + node { + name: "Reshape" + op: "Reshape" + input: "layer/add" + input: "concat" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: -1 + } + } + } + } + } + } + node { + name: "Rank_2" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } + } + node { + name: "Shape_2" + op: "Shape" + input: "Placeholder_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "Sub_1/y" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "Sub_1" + op: "Sub" + input: "Rank_2" + input: "Sub_1/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "Slice_1/begin" + op: "Pack" + input: "Sub_1" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node { + name: "Slice_1/size" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node { + name: "Slice_1" + op: "Slice" + input: "Shape_2" + input: "Slice_1/begin" + input: "Slice_1/size" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + } + node { + name: "concat_1/values_0" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: -1 + } + } + } + } + node { + name: "concat_1/axis" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + } + node { + name: "concat_1" + op: "ConcatV2" + input: "concat_1/values_0" + input: "Slice_1" + input: "concat_1/axis" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + } + node { + name: "Reshape_1" + op: "Reshape" + input: "Placeholder_1" + input: "concat_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: -1 + } + } + } + } + } + } + node { + name: "SoftmaxCrossEntropyWithLogits" + op: "SoftmaxCrossEntropyWithLogits" + input: "Reshape" + input: "Reshape_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + shape { + dim { + size: -1 + } + dim { + size: -1 + } + } + } + } + } + } + node { + name: "Sub_2/y" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "Sub_2" + op: "Sub" + input: "Rank" + input: "Sub_2/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "Slice_2/begin" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node { + name: "Slice_2/size" + op: "Pack" + input: "Sub_2" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node { + name: "Slice_2" + op: "Slice" + input: "Shape" + input: "Slice_2/begin" + input: "Slice_2/size" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "Reshape_2" + op: "Reshape" + input: "SoftmaxCrossEntropyWithLogits" + input: "Slice_2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "Const" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node { + name: "Mean" + op: "Mean" + input: "Reshape_2" + input: "Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "gradients/Shape" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node { + name: "gradients/Const" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } + } + node { + name: "gradients/Fill" + op: "Fill" + input: "gradients/Shape" + input: "gradients/Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "gradients/Mean_grad/Reshape/shape" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node { + name: "gradients/Mean_grad/Reshape" + op: "Reshape" + input: "gradients/Fill" + input: "gradients/Mean_grad/Reshape/shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + } + node { + name: "gradients/Mean_grad/Shape" + op: "Shape" + input: "Reshape_2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "gradients/Mean_grad/Tile" + op: "Tile" + input: "gradients/Mean_grad/Reshape" + input: "gradients/Mean_grad/Shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tmultiples" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "gradients/Mean_grad/Shape_1" + op: "Shape" + input: "Reshape_2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "gradients/Mean_grad/Shape_2" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node { + name: "gradients/Mean_grad/Const" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node { + name: "gradients/Mean_grad/Prod" + op: "Prod" + input: "gradients/Mean_grad/Shape_1" + input: "gradients/Mean_grad/Const" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "gradients/Mean_grad/Const_1" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node { + name: "gradients/Mean_grad/Prod_1" + op: "Prod" + input: "gradients/Mean_grad/Shape_2" + input: "gradients/Mean_grad/Const_1" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "gradients/Mean_grad/Maximum/y" + op: "Const" + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "gradients/Mean_grad/Maximum" + op: "Maximum" + input: "gradients/Mean_grad/Prod_1" + input: "gradients/Mean_grad/Maximum/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "gradients/Mean_grad/floordiv" + op: "FloorDiv" + input: "gradients/Mean_grad/Prod" + input: "gradients/Mean_grad/Maximum" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/Mean_grad/Shape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "gradients/Mean_grad/Cast" + op: "Cast" + input: "gradients/Mean_grad/floordiv" + attr { + key: "DstT" + value { + type: DT_FLOAT + } + } + attr { + key: "SrcT" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "gradients/Mean_grad/truediv" + op: "RealDiv" + input: "gradients/Mean_grad/Tile" + input: "gradients/Mean_grad/Cast" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "gradients/Reshape_2_grad/Shape" + op: "Shape" + input: "SoftmaxCrossEntropyWithLogits" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "gradients/Reshape_2_grad/Reshape" + op: "Reshape" + input: "gradients/Mean_grad/truediv" + input: "gradients/Reshape_2_grad/Shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "gradients/zeros_like" + op: "ZerosLike" + input: "SoftmaxCrossEntropyWithLogits:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: -1 + } + } + } + } + } + } + node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: -1 + } + } + } + } + node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" + op: "ExpandDims" + input: "gradients/Reshape_2_grad/Reshape" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tdim" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + } + } + node { + name: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" + op: "Mul" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" + input: "SoftmaxCrossEntropyWithLogits:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: -1 + } + } + } + } + } + } + node { + name: "gradients/Reshape_grad/Shape" + op: "Shape" + input: "layer/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "gradients/Reshape_grad/Reshape" + op: "Reshape" + input: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" + input: "gradients/Reshape_grad/Shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "gradients/layer/add_grad/Shape" + op: "Shape" + input: "layer/MatMul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node { + name: "gradients/layer/add_grad/Shape_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 10 + } + } + } + } + node { + name: "gradients/layer/add_grad/BroadcastGradientArgs" + op: "BroadcastGradientArgs" + input: "gradients/layer/add_grad/Shape" + input: "gradients/layer/add_grad/Shape_1" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "gradients/layer/add_grad/Sum" + op: "Sum" + input: "gradients/Reshape_grad/Reshape" + input: "gradients/layer/add_grad/BroadcastGradientArgs" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "gradients/layer/add_grad/Reshape" + op: "Reshape" + input: "gradients/layer/add_grad/Sum" + input: "gradients/layer/add_grad/Shape" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "gradients/layer/add_grad/Sum_1" + op: "Sum" + input: "gradients/Reshape_grad/Reshape" + input: "gradients/layer/add_grad/BroadcastGradientArgs:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "gradients/layer/add_grad/Reshape_1" + op: "Reshape" + input: "gradients/layer/add_grad/Sum_1" + input: "gradients/layer/add_grad/Shape_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + } + node { + name: "gradients/layer/add_grad/tuple/group_deps" + op: "NoOp" + input: "^gradients/layer/add_grad/Reshape" + input: "^gradients/layer/add_grad/Reshape_1" + } + node { + name: "gradients/layer/add_grad/tuple/control_dependency" + op: "Identity" + input: "gradients/layer/add_grad/Reshape" + input: "^gradients/layer/add_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/layer/add_grad/Reshape" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "gradients/layer/add_grad/tuple/control_dependency_1" + op: "Identity" + input: "gradients/layer/add_grad/Reshape_1" + input: "^gradients/layer/add_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/layer/add_grad/Reshape_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + } + node { + name: "gradients/layer/MatMul_grad/MatMul" + op: "MatMul" + input: "gradients/layer/add_grad/tuple/control_dependency" + input: "layer/Variable/read" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: true + } + } + } + node { + name: "gradients/layer/MatMul_grad/MatMul_1" + op: "MatMul" + input: "Placeholder" + input: "gradients/layer/add_grad/tuple/control_dependency" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: true + } + } + attr { + key: "transpose_b" + value { + b: false + } + } + } + node { + name: "gradients/layer/MatMul_grad/tuple/group_deps" + op: "NoOp" + input: "^gradients/layer/MatMul_grad/MatMul" + input: "^gradients/layer/MatMul_grad/MatMul_1" + } + node { + name: "gradients/layer/MatMul_grad/tuple/control_dependency" + op: "Identity" + input: "gradients/layer/MatMul_grad/MatMul" + input: "^gradients/layer/MatMul_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/layer/MatMul_grad/MatMul" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + } + } + node { + name: "gradients/layer/MatMul_grad/tuple/control_dependency_1" + op: "Identity" + input: "gradients/layer/MatMul_grad/MatMul_1" + input: "^gradients/layer/MatMul_grad/tuple/group_deps" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@gradients/layer/MatMul_grad/MatMul_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + } + node { + name: "GradientDescent/learning_rate" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.5 + } + } + } + } + node { + name: "GradientDescent/update_layer/Variable/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "layer/Variable" + input: "GradientDescent/learning_rate" + input: "gradients/layer/MatMul_grad/tuple/control_dependency_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@layer/Variable" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } + } + node { + name: "GradientDescent/update_layer/Variable_1/ApplyGradientDescent" + op: "ApplyGradientDescent" + input: "layer/Variable_1" + input: "GradientDescent/learning_rate" + input: "gradients/layer/add_grad/tuple/control_dependency_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@layer/Variable_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: false + } + } + } + node { + name: "GradientDescent" + op: "NoOp" + input: "^GradientDescent/update_layer/Variable/ApplyGradientDescent" + input: "^GradientDescent/update_layer/Variable_1/ApplyGradientDescent" + } + node { + name: "init" + op: "NoOp" + input: "^layer/Variable/Assign" + input: "^layer/Variable_1/Assign" + } + node { + name: "ArgMax/dimension" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "ArgMax" + op: "ArgMax" + input: "layer/add" + input: "ArgMax/dimension" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + attr { + key: "output_type" + value { + type: DT_INT64 + } + } + } + node { + name: "ArgMax_1/dimension" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "ArgMax_1" + op: "ArgMax" + input: "Placeholder_1" + input: "ArgMax_1/dimension" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + attr { + key: "output_type" + value { + type: DT_INT64 + } + } + } + node { + name: "Equal" + op: "Equal" + input: "ArgMax" + input: "ArgMax_1" + attr { + key: "T" + value { + type: DT_INT64 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "Cast_1" + op: "Cast" + input: "Equal" + attr { + key: "DstT" + value { + type: DT_FLOAT + } + } + attr { + key: "SrcT" + value { + type: DT_BOOL + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + } + node { + name: "Const_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node { + name: "Mean_1" + op: "Mean" + input: "Cast_1" + input: "Const_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "save/Const" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "model" + } + } + } + } + node { + name: "save/StringJoin/inputs_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "_temp_65caff16d5244276b9828b0dab21b157/part" + } + } + } + } + node { + name: "save/StringJoin" + op: "StringJoin" + input: "save/Const" + input: "save/StringJoin/inputs_1" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "separator" + value { + s: "" + } + } + } + node { + name: "save/num_shards" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "save/ShardedFilename/shard" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + } + node { + name: "save/ShardedFilename" + op: "ShardedFilename" + input: "save/StringJoin" + input: "save/ShardedFilename/shard" + input: "save/num_shards" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "save/SaveV2/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 2 + } + } + string_val: "layer/Variable" + string_val: "layer/Variable_1" + } + } + } + } + node { + name: "save/SaveV2/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 2 + } + } + string_val: "" + string_val: "" + } + } + } + } + node { + name: "save/SaveV2" + op: "SaveV2" + input: "save/ShardedFilename" + input: "save/SaveV2/tensor_names" + input: "save/SaveV2/shape_and_slices" + input: "layer/Variable" + input: "layer/Variable_1" + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + type: DT_FLOAT + } + } + } + } + node { + name: "save/control_dependency" + op: "Identity" + input: "save/ShardedFilename" + input: "^save/SaveV2" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@save/ShardedFilename" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "save/MergeV2Checkpoints/checkpoint_prefixes" + op: "Pack" + input: "save/ShardedFilename" + input: "^save/control_dependency" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node { + name: "save/MergeV2Checkpoints" + op: "MergeV2Checkpoints" + input: "save/MergeV2Checkpoints/checkpoint_prefixes" + input: "save/Const" + attr { + key: "delete_old_dirs" + value { + b: true + } + } + } + node { + name: "save/Identity" + op: "Identity" + input: "save/Const" + input: "^save/control_dependency" + input: "^save/MergeV2Checkpoints" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node { + name: "save/RestoreV2/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "layer/Variable" + } + } + } + } + node { + name: "save/RestoreV2/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "" + } + } + } + } + node { + name: "save/RestoreV2" + op: "RestoreV2" + input: "save/Const" + input: "save/RestoreV2/tensor_names" + input: "save/RestoreV2/shape_and_slices" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + } + } + } + } + node { + name: "save/Assign" + op: "Assign" + input: "layer/Variable" + input: "save/RestoreV2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@layer/Variable" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 784 + } + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "save/RestoreV2_1/tensor_names" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "layer/Variable_1" + } + } + } + } + node { + name: "save/RestoreV2_1/shape_and_slices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "" + } + } + } + } + node { + name: "save/RestoreV2_1" + op: "RestoreV2" + input: "save/Const" + input: "save/RestoreV2_1/tensor_names" + input: "save/RestoreV2_1/shape_and_slices" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + } + } + } + } + node { + name: "save/Assign_1" + op: "Assign" + input: "layer/Variable_1" + input: "save/RestoreV2_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@layer/Variable_1" + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } + attr { + key: "use_locking" + value { + b: true + } + } + attr { + key: "validate_shape" + value { + b: true + } + } + } + node { + name: "save/restore_shard" + op: "NoOp" + input: "^save/Assign" + input: "^save/Assign_1" + } + node { + name: "save/restore_all" + op: "NoOp" + input: "^save/restore_shard" + } + versions { + producer: 24 + } + } + saver_def { + filename_tensor_name: "save/Const:0" + save_tensor_name: "save/Identity:0" + restore_op_name: "save/restore_all" + max_to_keep: 5 + sharded: true + keep_checkpoint_every_n_hours: 10000.0 + version: V2 + } + collection_def { + key: "train_op" + value { + node_list { + value: "GradientDescent" + } + } + } + collection_def { + key: "trainable_variables" + value { + bytes_list { + value: "\n\020layer/Variable:0\022\025layer/Variable/Assign\032\025layer/Variable/read:02\rlayer/zeros:0" + value: "\n\022layer/Variable_1:0\022\027layer/Variable_1/Assign\032\027layer/Variable_1/read:02\017layer/zeros_1:0" + } + } + } + collection_def { + key: "variables" + value { + bytes_list { + value: "\n\020layer/Variable:0\022\025layer/Variable/Assign\032\025layer/Variable/read:02\rlayer/zeros:0" + value: "\n\022layer/Variable_1:0\022\027layer/Variable_1/Assign\032\027layer/Variable_1/read:02\017layer/zeros_1:0" + } + } + } + signature_def { + key: "serving_default" + value { + inputs { + key: "x" + value { + name: "Placeholder:0" + dtype: DT_FLOAT + tensor_shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + } + } + outputs { + key: "y" + value { + name: "layer/add:0" + dtype: DT_FLOAT + tensor_shape { + dim { + size: -1 + } + dim { + size: 10 + } + } + } + } + method_name: "tensorflow/serving/predict" + } + } +} diff --git a/config-model/src/test/cfg/application/ml_serving/models/mnist_softmax/saved/variables/variables.data-00000-of-00001 b/config-model/src/test/cfg/application/ml_serving/models/mnist_softmax/saved/variables/variables.data-00000-of-00001 Binary files differnew file mode 100644 index 00000000000..826b0280abf --- /dev/null +++ b/config-model/src/test/cfg/application/ml_serving/models/mnist_softmax/saved/variables/variables.data-00000-of-00001 diff --git a/config-model/src/test/cfg/application/ml_serving/models/mnist_softmax/saved/variables/variables.index b/config-model/src/test/cfg/application/ml_serving/models/mnist_softmax/saved/variables/variables.index Binary files differnew file mode 100644 index 00000000000..d00fc5b06ed --- /dev/null +++ b/config-model/src/test/cfg/application/ml_serving/models/mnist_softmax/saved/variables/variables.index diff --git a/config-model/src/test/cfg/application/ml_serving/models/xgboost.2.2.json b/config-model/src/test/cfg/application/ml_serving/models/xgboost.2.2.json new file mode 100644 index 00000000000..f8949b47e52 --- /dev/null +++ b/config-model/src/test/cfg/application/ml_serving/models/xgboost.2.2.json @@ -0,0 +1,19 @@ +[ + { "nodeid": 0, "depth": 0, "split": "f29", "split_condition": -0.1234567, "yes": 1, "no": 2, "missing": 1, "children": [ + { "nodeid": 1, "depth": 1, "split": "f56", "split_condition": -0.242398, "yes": 3, "no": 4, "missing": 3, "children": [ + { "nodeid": 3, "leaf": 1.71218 }, + { "nodeid": 4, "leaf": -1.70044 } + ]}, + { "nodeid": 2, "depth": 1, "split": "f109", "split_condition": 0.8723473, "yes": 5, "no": 6, "missing": 5, "children": [ + { "nodeid": 5, "leaf": -1.94071 }, + { "nodeid": 6, "leaf": 1.85965 } + ]} + ]}, + { "nodeid": 0, "depth": 0, "split": "f60", "split_condition": -0.482947, "yes": 1, "no": 2, "missing": 1, "children": [ + { "nodeid": 1, "depth": 1, "split": "f29", "split_condition": -4.2387498, "yes": 3, "no": 4, "missing": 3, "children": [ + { "nodeid": 3, "leaf": 0.784718 }, + { "nodeid": 4, "leaf": -0.96853 } + ]}, + { "nodeid": 2, "leaf": -6.23624 } + ]} +]
\ No newline at end of file diff --git a/config-model/src/test/cfg/application/ml_serving/services.xml b/config-model/src/test/cfg/application/ml_serving/services.xml new file mode 100644 index 00000000000..42528336bc5 --- /dev/null +++ b/config-model/src/test/cfg/application/ml_serving/services.xml @@ -0,0 +1,12 @@ +<?xml version="1.0" encoding="utf-8" ?> +<!-- Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> +<services version="1.0"> + + <container version="1.0"> + <nodes> + <node hostalias="node1" /> + </nodes> + + </container> + +</services> diff --git a/config-model/src/test/java/com/yahoo/config/model/ApplicationDeployTest.java b/config-model/src/test/java/com/yahoo/config/model/ApplicationDeployTest.java index 643a3bd0b91..ded8d88aa99 100644 --- a/config-model/src/test/java/com/yahoo/config/model/ApplicationDeployTest.java +++ b/config-model/src/test/java/com/yahoo/config/model/ApplicationDeployTest.java @@ -18,7 +18,9 @@ import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.UnproperSearch; import com.yahoo.vespa.config.ConfigDefinition; import com.yahoo.vespa.config.ConfigDefinitionKey; +import com.yahoo.vespa.config.search.RankProfilesConfig; import com.yahoo.vespa.model.VespaModel; +import com.yahoo.vespa.model.container.ContainerCluster; import com.yahoo.vespa.model.search.SearchDefinition; import org.junit.After; import org.junit.Rule; @@ -35,9 +37,11 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.jar.JarEntry; import java.util.jar.JarFile; import java.util.regex.Pattern; +import java.util.stream.Collectors; import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.CoreMatchers.is; @@ -118,6 +122,21 @@ public class ApplicationDeployTest { } @Test + public void testMl_ServingApplication() throws SAXException, IOException { + FilesApplicationPackage app = createAppPkg(TESTDIR + "ml_serving"); + VespaModel model = new VespaModel(app); + ContainerCluster cluster = model.getContainerClusters().get("container"); + RankProfilesConfig.Builder b = new RankProfilesConfig.Builder(); + cluster.getConfig(b); + RankProfilesConfig config = new RankProfilesConfig(b); + assertEquals(3, config.rankprofile().size()); + Set<String> modelNames = config.rankprofile().stream().map(v -> v.name()).collect(Collectors.toSet()); + assertTrue(modelNames.contains("xgboost_2_2_json")); + assertTrue(modelNames.contains("mnist_softmax_onnx")); + assertTrue(modelNames.contains("mnist_softmax_saved")); + } + + @Test public void testGetFile() throws IOException { FilesApplicationPackage app = createAppPkg(TESTDIR + "app1"); try (Reader foo = app.getFile(Path.fromString("files/foo.json")).createReader()) { @@ -179,8 +198,9 @@ public class ApplicationDeployTest { @Test public void non_existent_include_dir_is_not_allowed() throws Exception { File appDir = tmpFolder.newFolder("non-existent-include"); - String services = "<services version='1.0'>" + - "<include dir='non-existent' />" + + String services = + "<services version='1.0'>" + + " <include dir='non-existent' />" + "</services>\n"; IOUtils.writeFile(new File(appDir, "services.xml"), services, false); @@ -197,11 +217,11 @@ public class ApplicationDeployTest { File tmpDir = tmpFolder.getRoot(); IOUtils.copyDirectory(new File(TESTDIR, "app1"), tmpDir); FilesApplicationPackage app = createAppPkg(tmpDir.getAbsolutePath()); - assertThat(getSearchDefinitions(app).size(), is(5)); + assertEquals(5, getSearchDefinitions(app).size()); File sdDir = new File(tmpDir, "searchdefinitions"); File sd = new File(sdDir, "testfoo.sd"); IOUtils.writeFile(sd, "search testfoo { document testfoo { field bar type string { } } }", false); - assertThat(getSearchDefinitions(app).size(), is(6)); + assertEquals(6, getSearchDefinitions(app).size()); } @Test @@ -293,7 +313,7 @@ public class ApplicationDeployTest { String appName = "src/test/cfg//application/app1"; FilesApplicationPackage app = FilesApplicationPackage.fromFile(new File(appName), false); Map<ConfigDefinitionKey, UnparsedConfigDefinition> defs = app.getAllExistingConfigDefs(); - assertThat(defs.size(), is(5)); + assertEquals(5, defs.size()); } @Test diff --git a/config-model/src/test/java/com/yahoo/config/model/ConfigModelContextTest.java b/config-model/src/test/java/com/yahoo/config/model/ConfigModelContextTest.java index aa17ea894bf..cb51cf23830 100644 --- a/config-model/src/test/java/com/yahoo/config/model/ConfigModelContextTest.java +++ b/config-model/src/test/java/com/yahoo/config/model/ConfigModelContextTest.java @@ -9,15 +9,12 @@ import com.yahoo.config.model.test.MockApplicationPackage; import com.yahoo.config.model.test.MockRoot; import org.junit.Test; -import java.util.Optional; - import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.CoreMatchers.not; import static org.junit.Assert.assertThat; /** - * @author lulf - * @since 5.1 + * @author Ulf Lilleengen */ public class ConfigModelContextTest { @@ -30,12 +27,12 @@ public class ConfigModelContextTest { .build(); DeployState deployState = DeployState.createTestState(pkg); DeployLogger logger = deployState.getDeployLogger(); - ConfigModelContext ctx = ConfigModelContext.create(deployState, null, root, id); + ConfigModelContext ctx = ConfigModelContext.create(deployState, null, null, root, id); assertThat(ctx.getApplicationPackage(), is(pkg)); assertThat(ctx.getProducerId(), is(id)); assertThat(ctx.getParentProducer(), is(root)); assertThat(ctx.getDeployLogger(), is(logger)); - ctx = ConfigModelContext.create(null, root, id); + ctx = ConfigModelContext.create(null, null, root, id); assertThat(ctx.getProducerId(), is(id)); assertThat(ctx.getParentProducer(), is(root)); AbstractConfigProducer newRoot = new MockRoot("bar"); diff --git a/config-model/src/test/java/com/yahoo/config/model/graph/ModelGraphTest.java b/config-model/src/test/java/com/yahoo/config/model/graph/ModelGraphTest.java index 34da6e588be..85c5f5ece45 100644 --- a/config-model/src/test/java/com/yahoo/config/model/graph/ModelGraphTest.java +++ b/config-model/src/test/java/com/yahoo/config/model/graph/ModelGraphTest.java @@ -64,10 +64,10 @@ public class ModelGraphTest { ModelGraph graph = new ModelGraphBuilder().addBuilder(new GraphMock.BC()).addBuilder(new GraphMock.BB()).addBuilder(new GraphMock.BA()).build(); List<ModelNode> nodes = graph.topologicalSort(); MockRoot root = new MockRoot(); - GraphMock.A a = (GraphMock.A) nodes.get(0).createModel(ConfigModelContext.create(null, root, "first")); - GraphMock.B b = (GraphMock.B) nodes.get(1).createModel(ConfigModelContext.create(null, root, "second")); - GraphMock.B b2 = (GraphMock.B) nodes.get(1).createModel(ConfigModelContext.create(null, root, "second2")); - GraphMock.C c = (GraphMock.C) nodes.get(2).createModel(ConfigModelContext.create(null, root, "third")); + GraphMock.A a = (GraphMock.A) nodes.get(0).createModel(ConfigModelContext.create(null, null, root, "first")); + GraphMock.B b = (GraphMock.B) nodes.get(1).createModel(ConfigModelContext.create(null, null, root, "second")); + GraphMock.B b2 = (GraphMock.B) nodes.get(1).createModel(ConfigModelContext.create(null, null, root, "second2")); + GraphMock.C c = (GraphMock.C) nodes.get(2).createModel(ConfigModelContext.create(null, null, root, "third")); assertNotNull(a); assertNotNull(b); assertNotNull(b2); @@ -91,7 +91,7 @@ public class ModelGraphTest { expectedEx.expect(IllegalArgumentException.class); expectedEx.expectMessage("Constructor for " + GraphMock.Bad.class.getName() + " must have as its first argument a " + ConfigModelContext.class.getName()); ModelNode node = new ModelNode(new GraphMock.Bad.Builder()); - node.createModel(ConfigModelContext.create(null, new MockRoot(), "foo")); + node.createModel(ConfigModelContext.create(null, null, new MockRoot(), "foo")); } @Test @@ -99,7 +99,7 @@ public class ModelGraphTest { expectedEx.expect(IllegalArgumentException.class); expectedEx.expectMessage("Unable to find constructor argument class java.lang.String for com.yahoo.config.model.graph.GraphMock$Bad2"); ModelNode node = new ModelNode(new GraphMock.Bad2.Builder()); - node.createModel(ConfigModelContext.create(null, new MockRoot(), "foo")); + node.createModel(ConfigModelContext.create(null, null, new MockRoot(), "foo")); } @Test @@ -107,8 +107,8 @@ public class ModelGraphTest { ModelGraph graph = new ModelGraphBuilder().addBuilder(new GraphMock.BC()).addBuilder(new GraphMock.BA()).build(); List<ModelNode> nodes = graph.topologicalSort(); MockRoot root = new MockRoot(); - GraphMock.A a = (GraphMock.A) nodes.get(0).createModel(ConfigModelContext.create(null, root, "first")); - GraphMock.C c = (GraphMock.C) nodes.get(1).createModel(ConfigModelContext.create(null, root, "second")); + GraphMock.A a = (GraphMock.A) nodes.get(0).createModel(ConfigModelContext.create(null, null, root, "first")); + GraphMock.C c = (GraphMock.C) nodes.get(1).createModel(ConfigModelContext.create(null, null, root, "second")); assertThat(c.a, is(a)); assertTrue(c.b.isEmpty()); } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/DiversityTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/DiversityTestCase.java index edc0462c6ee..e20bc4d96aa 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/DiversityTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/DiversityTestCase.java @@ -1,7 +1,6 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition; -import com.yahoo.config.model.application.provider.BaseDeployLogger; import com.yahoo.search.query.ranking.Diversity; import com.yahoo.searchdefinition.parser.ParseException; import org.junit.Test; @@ -45,7 +44,7 @@ public class DiversityTestCase { "}\n"); builder.build(); Search s = builder.getSearch(); - RankProfile.MatchPhaseSettings matchPhase = rankProfileRegistry.getRankProfile(s, "parent").getMatchPhaseSettings(); + RankProfile.MatchPhaseSettings matchPhase = rankProfileRegistry.get(s, "parent").getMatchPhaseSettings(); RankProfile.DiversitySettings diversity = matchPhase.getDiversity(); assertEquals("b", diversity.getAttribute()); assertEquals(74, diversity.getMinGroups()); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileRegistryTest.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileRegistryTest.java index 6999518e706..28559f351ac 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileRegistryTest.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileRegistryTest.java @@ -9,10 +9,9 @@ import org.junit.Test; import java.io.File; -import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertThat; /** * @author Ulf Lilleengen @@ -25,8 +24,8 @@ public class RankProfileRegistryTest { TestRoot root = new TestDriver().buildModel(FilesApplicationPackage.fromFile(new File(TESTDIR))); RankProfilesConfig left = root.getConfig(RankProfilesConfig.class, "inherit/search/cluster.inherit/left"); RankProfilesConfig right = root.getConfig(RankProfilesConfig.class, "inherit/search/cluster.inherit/right"); - assertThat(left.rankprofile().size(), is(3)); - assertThat(right.rankprofile().size(), is(2)); + assertEquals(3, left.rankprofile().size()); + assertEquals(2, right.rankprofile().size()); } @Test(expected = IllegalArgumentException.class) @@ -34,8 +33,8 @@ public class RankProfileRegistryTest { Search search = new Search("foo", null); RankProfileRegistry rankProfileRegistry = RankProfileRegistry.createRankProfileRegistryWithBuiltinRankProfiles(search); RankProfile barRankProfile = new RankProfile("bar", search, rankProfileRegistry); - rankProfileRegistry.addRankProfile(barRankProfile); - rankProfileRegistry.addRankProfile(barRankProfile); + rankProfileRegistry.add(barRankProfile); + rankProfileRegistry.add(barRankProfile); } @Test @@ -44,11 +43,11 @@ public class RankProfileRegistryTest { RankProfileRegistry rankProfileRegistry = RankProfileRegistry.createRankProfileRegistryWithBuiltinRankProfiles(search); for (String rankProfileName : RankProfileRegistry.overridableRankProfileNames) { - assertNull(rankProfileRegistry.getRankProfile(search, rankProfileName).getMacros().get("foo")); + assertNull(rankProfileRegistry.get(search, rankProfileName).getMacros().get("foo")); RankProfile rankProfileWithAddedMacro = new RankProfile(rankProfileName, search, rankProfileRegistry); rankProfileWithAddedMacro.addMacro("foo", true); - rankProfileRegistry.addRankProfile(rankProfileWithAddedMacro); - assertNotNull(rankProfileRegistry.getRankProfile(search, rankProfileName).getMacros().get("foo")); + rankProfileRegistry.add(rankProfileWithAddedMacro); + assertNotNull(rankProfileRegistry.get(search, rankProfileName).getMacros().get("foo")); } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java index de9df08f5c0..4df3add13c5 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java @@ -17,7 +17,6 @@ import com.yahoo.searchdefinition.document.SDDocumentType; import com.yahoo.searchdefinition.document.SDField; import com.yahoo.searchdefinition.parser.ParseException; import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; -import com.yahoo.vespa.model.container.search.QueryProfiles; import org.junit.Test; import java.util.Iterator; @@ -46,7 +45,7 @@ public class RankProfileTestCase extends SearchDefinitionTestCase { search.addDocument(document); RankProfile child = new RankProfile("child", search, rankProfileRegistry); child.setInherited("default"); - rankProfileRegistry.addRankProfile(child); + rankProfileRegistry.add(child); Iterator<RankProfile.RankSetting> i = child.rankSettingIterator(); @@ -83,8 +82,8 @@ public class RankProfileTestCase extends SearchDefinitionTestCase { builder.build(); Search search = builder.getSearch(); AttributeFields attributeFields = new AttributeFields(search); - verifyRankProfile(rankProfileRegistry.getRankProfile(search, "parent"), attributeFields); - verifyRankProfile(rankProfileRegistry.getRankProfile(search, "child"), attributeFields); + verifyRankProfile(rankProfileRegistry.get(search, "parent"), attributeFields); + verifyRankProfile(rankProfileRegistry.get(search, "child"), attributeFields); } private void verifyRankProfile(RankProfile rankProfile, AttributeFields attributeFields) { @@ -119,11 +118,11 @@ public class RankProfileTestCase extends SearchDefinitionTestCase { builder.build(); Search search = builder.getSearch(); - assertEquals(4, registry.allRankProfiles().size()); - assertAttributeTypeSettings(registry.getRankProfile(search, "default"), search); - assertAttributeTypeSettings(registry.getRankProfile(search, "unranked"), search); - assertAttributeTypeSettings(registry.getRankProfile(search, "p1"), search); - assertAttributeTypeSettings(registry.getRankProfile(search, "p2"), search); + assertEquals(4, registry.all().size()); + assertAttributeTypeSettings(registry.get(search, "default"), search); + assertAttributeTypeSettings(registry.get(search, "unranked"), search); + assertAttributeTypeSettings(registry.get(search, "p1"), search); + assertAttributeTypeSettings(registry.get(search, "p2"), search); } private static void assertAttributeTypeSettings(RankProfile profile, Search search) { @@ -145,11 +144,11 @@ public class RankProfileTestCase extends SearchDefinitionTestCase { builder.build(true, new BaseDeployLogger()); Search search = builder.getSearch(); - assertEquals(4, registry.allRankProfiles().size()); - assertQueryFeatureTypeSettings(registry.getRankProfile(search, "default"), search); - assertQueryFeatureTypeSettings(registry.getRankProfile(search, "unranked"), search); - assertQueryFeatureTypeSettings(registry.getRankProfile(search, "p1"), search); - assertQueryFeatureTypeSettings(registry.getRankProfile(search, "p2"), search); + assertEquals(4, registry.all().size()); + assertQueryFeatureTypeSettings(registry.get(search, "default"), search); + assertQueryFeatureTypeSettings(registry.get(search, "unranked"), search); + assertQueryFeatureTypeSettings(registry.get(search, "p1"), search); + assertQueryFeatureTypeSettings(registry.get(search, "p2"), search); } private static QueryProfileRegistry setupQueryProfileTypes() { diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java index 3a2482b56d0..8df3985fd24 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java @@ -1,8 +1,6 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition; -import com.yahoo.collections.Pair; -import com.yahoo.config.model.application.provider.BaseDeployLogger; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; @@ -10,10 +8,6 @@ import com.yahoo.searchdefinition.parser.ParseException; import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import org.junit.Test; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; - import static org.junit.Assert.assertEquals; /** @@ -57,7 +51,7 @@ public class RankPropertiesTestCase extends SearchDefinitionTestCase { { // Check declared model - RankProfile parent = rankProfileRegistry.getRankProfile(search, "parent"); + RankProfile parent = rankProfileRegistry.get(search, "parent"); assertEquals("query(a) = 1500", parent.getRankProperties().get(0).toString()); // Check derived model @@ -67,11 +61,11 @@ public class RankPropertiesTestCase extends SearchDefinitionTestCase { { // Check declared model - RankProfile parent = rankProfileRegistry.getRankProfile(search, "child"); + RankProfile parent = rankProfileRegistry.get(search, "child"); assertEquals("query(a) = 2000", parent.getRankProperties().get(0).toString()); // Check derived model - RawRankProfile rawChild = new RawRankProfile(rankProfileRegistry.getRankProfile(search, "child"), + RawRankProfile rawChild = new RawRankProfile(rankProfileRegistry.get(search, "child"), new QueryProfileRegistry(), new ImportedModels(), attributeFields); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingConstantTest.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingConstantTest.java index 9e1fdd258e2..fd69e282fcb 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingConstantTest.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingConstantTest.java @@ -44,7 +44,7 @@ public class RankingConstantTest { searchBuilder.build(); Search search = searchBuilder.getSearch(); - Iterator<RankingConstant> constantIterator = search.getRankingConstants().values().iterator(); + Iterator<RankingConstant> constantIterator = search.rankingConstants().asMap().values().iterator(); RankingConstant constant = constantIterator.next(); assertEquals(TENSOR_NAME, constant.getName()); assertEquals(TENSOR_FILE, constant.getFileName()); @@ -101,7 +101,7 @@ public class RankingConstantTest { )); searchBuilder.build(); Search search = searchBuilder.getSearch(); - RankingConstant constant = search.getRankingConstants().values().iterator().next(); + RankingConstant constant = search.rankingConstants().asMap().values().iterator().next(); assertEquals("simplename", constant.getFileName()); } @@ -120,7 +120,7 @@ public class RankingConstantTest { )); searchBuilder.build(); Search search = searchBuilder.getSearch(); - RankingConstant constant = search.getRankingConstants().values().iterator().next(); + RankingConstant constant = search.rankingConstants().asMap().values().iterator().next(); assertEquals(RankingConstant.PathType.URI, constant.getPathType()); assertEquals("http://somewhere.far.away/in/another-galaxy", constant.getUri()); } @@ -140,7 +140,7 @@ public class RankingConstantTest { )); searchBuilder.build(); Search search = searchBuilder.getSearch(); - RankingConstant constant = search.getRankingConstants().values().iterator().next(); + RankingConstant constant = search.rankingConstants().asMap().values().iterator().next(); assertEquals(RankingConstant.PathType.URI, constant.getPathType()); assertEquals("https://somewhere.far.away:4443/in/another-galaxy", constant.getUri()); } @@ -160,7 +160,7 @@ public class RankingConstantTest { )); searchBuilder.build(); Search search = searchBuilder.getSearch(); - RankingConstant constant = search.getRankingConstants().values().iterator().next(); + RankingConstant constant = search.rankingConstants().asMap().values().iterator().next(); assertEquals(RankingConstant.PathType.URI, constant.getPathType()); assertEquals("http://somewhere.far.away:4080/in/another-galaxy", constant.getUri()); } @@ -180,7 +180,7 @@ public class RankingConstantTest { )); searchBuilder.build(); Search search = searchBuilder.getSearch(); - RankingConstant constant = search.getRankingConstants().values().iterator().next(); + RankingConstant constant = search.rankingConstants().asMap().values().iterator().next(); assertEquals(RankingConstant.PathType.URI, constant.getPathType()); assertEquals("http:somewhere.far.away/in/another-galaxy", constant.getUri()); } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java index da546967dc1..a524a26cbef 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java @@ -2,7 +2,6 @@ package com.yahoo.searchdefinition; import com.yahoo.collections.Pair; -import com.yahoo.config.model.application.provider.BaseDeployLogger; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import com.yahoo.yolean.Exceptions; @@ -11,9 +10,7 @@ import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; import org.junit.Test; -import java.util.ArrayList; import java.util.List; -import java.util.Map; import static org.junit.Assert.*; import static org.junit.Assert.assertEquals; @@ -70,14 +67,14 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase "}\n"); builder.build(); Search s = builder.getSearch(); - RankProfile parent = rankProfileRegistry.getRankProfile(s, "parent").compile(queryProfileRegistry, new ImportedModels()); + RankProfile parent = rankProfileRegistry.get(s, "parent").compile(queryProfileRegistry, new ImportedModels()); assertEquals("0.0", parent.getFirstPhaseRanking().getRoot().toString()); - RankProfile child1 = rankProfileRegistry.getRankProfile(s, "child1").compile(queryProfileRegistry, new ImportedModels()); + RankProfile child1 = rankProfileRegistry.get(s, "child1").compile(queryProfileRegistry, new ImportedModels()); assertEquals("6.5", child1.getFirstPhaseRanking().getRoot().toString()); assertEquals("11.5", child1.getSecondPhaseRanking().getRoot().toString()); - RankProfile child2 = rankProfileRegistry.getRankProfile(s, "child2").compile(queryProfileRegistry, new ImportedModels()); + RankProfile child2 = rankProfileRegistry.get(s, "child2").compile(queryProfileRegistry, new ImportedModels()); assertEquals("16.6", child2.getFirstPhaseRanking().getRoot().toString()); assertEquals("foo: 14.0", child2.getMacros().get("foo").getRankingExpression().toString()); List<Pair<String, String>> rankProperties = new RawRankProfile(child2, @@ -113,7 +110,7 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase builder.build(); Search s = builder.getSearch(); try { - rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry(), new ImportedModels()); + rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedModels()); fail("Should have caused an exception"); } catch (IllegalArgumentException e) { @@ -143,7 +140,7 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase "}\n"); builder.build(); Search s = builder.getSearch(); - RankProfile profile = rankProfileRegistry.getRankProfile(s, "test"); + RankProfile profile = rankProfileRegistry.get(s, "test"); profile.parseExpressions(); // TODO: Do differently assertEquals("safeLog(popShareSlowDecaySignal,-9.21034037)", profile.getMacros().get("POP_SLOW_SCORE").getRankingExpression().getRoot().toString()); } @@ -172,7 +169,7 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase "}\n"); builder.build(); Search s = builder.getSearch(); - RankProfile profile = rankProfileRegistry.getRankProfile(s, "test"); + RankProfile profile = rankProfileRegistry.get(s, "test"); profile.parseExpressions(); // TODO: Do differently assertEquals("safeLog(popShareSlowDecaySignal,myValue)", profile.getMacros().get("POP_SLOW_SCORE").getRankingExpression().getRoot().toString()); assertEquals("safeLog(popShareSlowDecaySignal,-9.21034037)", @@ -197,7 +194,7 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase "}\n"); builder.build(); Search s = builder.getSearch(); - RankProfile profile = rankProfileRegistry.getRankProfile(s, "test"); + RankProfile profile = rankProfileRegistry.get(s, "test"); assertEquals("k1 + (k2 + k3) / 100000000.0", profile.compile(new QueryProfileRegistry(), new ImportedModels()).getMacros().get("rank_default").getRankingExpression().getRoot().toString()); } @@ -223,7 +220,7 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase "}\n"); builder.build(); Search s = builder.getSearch(); - RankProfile profile = rankProfileRegistry.getRankProfile(s, "test"); + RankProfile profile = rankProfileRegistry.get(s, "test"); assertEquals("0.5 + 50 * (attribute(rating_yelp) - 3)", profile.compile(new QueryProfileRegistry(), new ImportedModels()).getMacros().get("rank_default").getRankingExpression().getRoot().toString()); } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java index 555aa698c65..e1ddd0c02ca 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java @@ -62,10 +62,10 @@ public class RankingExpressionInliningTestCase extends SearchDefinitionTestCase builder.build(); Search s = builder.getSearch(); - RankProfile parent = rankProfileRegistry.getRankProfile(s, "parent").compile(new QueryProfileRegistry(), new ImportedModels()); + RankProfile parent = rankProfileRegistry.get(s, "parent").compile(new QueryProfileRegistry(), new ImportedModels()); assertEquals("7.0 * (3 + attribute(a) + attribute(b) * (attribute(a) * 3 + if (7.0 < attribute(a), 1, 2) == 0))", parent.getFirstPhaseRanking().getRoot().toString()); - RankProfile child = rankProfileRegistry.getRankProfile(s, "child").compile(new QueryProfileRegistry(), new ImportedModels()); + RankProfile child = rankProfileRegistry.get(s, "child").compile(new QueryProfileRegistry(), new ImportedModels()); assertEquals("7.0 * (9 + attribute(a))", child.getFirstPhaseRanking().getRoot().toString()); } @@ -122,14 +122,14 @@ public class RankingExpressionInliningTestCase extends SearchDefinitionTestCase builder.build(); Search s = builder.getSearch(); - RankProfile parent = rankProfileRegistry.getRankProfile(s, "parent").compile(new QueryProfileRegistry(), new ImportedModels()); + RankProfile parent = rankProfileRegistry.get(s, "parent").compile(new QueryProfileRegistry(), new ImportedModels()); assertEquals("17.0", parent.getFirstPhaseRanking().getRoot().toString()); assertEquals("0.0", parent.getSecondPhaseRanking().getRoot().toString()); assertEquals("10.0", getRankingExpression("foo", parent, s)); assertEquals("17.0", getRankingExpression("firstphase", parent, s)); assertEquals("0.0", getRankingExpression("secondphase", parent, s)); - RankProfile child = rankProfileRegistry.getRankProfile(s, "child").compile(new QueryProfileRegistry(), new ImportedModels()); + RankProfile child = rankProfileRegistry.get(s, "child").compile(new QueryProfileRegistry(), new ImportedModels()); assertEquals("31.0 + bar + arg(4.0)", child.getFirstPhaseRanking().getRoot().toString()); assertEquals("24.0", child.getSecondPhaseRanking().getRoot().toString()); assertEquals("12.0", getRankingExpression("foo", child, s)); @@ -178,7 +178,7 @@ public class RankingExpressionInliningTestCase extends SearchDefinitionTestCase builder.build(); Search s = builder.getSearch(); - RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry(), new ImportedModels()); + RankProfile test = rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedModels()); assertEquals("attribute(a) + C + (attribute(b) + 1)", test.getFirstPhaseRanking().getRoot().toString()); assertEquals("attribute(a) + attribute(b)", getRankingExpression("C", test, s)); assertEquals("attribute(b) + 1", getRankingExpression("D", test, s)); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java index 2bd7d3031a5..1ece2355a92 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java @@ -10,12 +10,9 @@ import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; -import org.junit.Ignore; import org.junit.Test; -import java.util.ArrayList; import java.util.List; -import java.util.Map; import static org.junit.Assert.assertEquals; @@ -45,7 +42,7 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase "}\n"); builder.build(); Search s = builder.getSearch(); - RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry(), new ImportedModels()); + RankProfile test = rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedModels()); List<Pair<String, String>> testRankProperties = new RawRankProfile(test, new QueryProfileRegistry(), new ImportedModels(), @@ -89,7 +86,7 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase "}\n"); builder.build(); Search s = builder.getSearch(); - RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry(), new ImportedModels()); + RankProfile test = rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedModels()); List<Pair<String, String>> testRankProperties = new RawRankProfile(test, new QueryProfileRegistry(), new ImportedModels(), @@ -139,7 +136,7 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase "}\n"); builder.build(); Search s = builder.getSearch(); - RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry(), new ImportedModels()); + RankProfile test = rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedModels()); List<Pair<String, String>> testRankProperties = new RawRankProfile(test, new QueryProfileRegistry(), new ImportedModels(), @@ -203,7 +200,7 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase "}\n"); builder.build(); Search s = builder.getSearch(); - RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(queryProfiles, new ImportedModels()); + RankProfile test = rankProfileRegistry.get(s, "test").compile(queryProfiles, new ImportedModels()); List<Pair<String, String>> testRankProperties = new RawRankProfile(test, queryProfiles, new ImportedModels(), diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/SearchImporterTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/SearchImporterTestCase.java index ace70e69959..8cdfdd51637 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/SearchImporterTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/SearchImporterTestCase.java @@ -131,20 +131,20 @@ public class SearchImporterTestCase extends SearchDefinitionTestCase { assertEquals(Attribute.CollectionType.ARRAY, attribute.getCollectionType()); // Rank Profiles - RankProfile profile=rankProfileRegistry.getRankProfile(search, "default"); + RankProfile profile=rankProfileRegistry.get(search, "default"); assertNotNull(profile); assertNull(profile.getInheritedName()); assertEquals(null,profile.getDeclaredRankSetting("measurement", RankProfile.RankSetting.Type.RANKTYPE)); assertEquals(RankType.EMPTY, profile.getRankSetting("measurement", RankProfile.RankSetting.Type.RANKTYPE).getValue()); - profile=rankProfileRegistry.getRankProfile(search, "experimental"); + profile=rankProfileRegistry.get(search, "experimental"); assertNotNull(profile); assertEquals("default",profile.getInheritedName()); assertEquals(RankType.IDENTITY, profile.getDeclaredRankSetting("measurement", RankProfile.RankSetting.Type.RANKTYPE).getValue()); - profile=rankProfileRegistry.getRankProfile(search, "other"); + profile=rankProfileRegistry.get(search, "other"); assertNotNull(profile); assertEquals("experimental",profile.getInheritedName()); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java index dec4b734f27..94d0bf6329a 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java @@ -38,7 +38,7 @@ public class LiteralBoostTestCase extends AbstractExportingTestCase { field1.parseIndexingScript("{ index }"); field1.setLiteralBoost(20); RankProfile other=new RankProfile("other", search, rankProfileRegistry); - rankProfileRegistry.addRankProfile(other); + rankProfileRegistry.add(other); other.addRankSetting(new RankProfile.RankSetting("a", RankProfile.RankSetting.Type.LITERALBOOST, 333)); Processing.process(search, new BaseDeployLogger(), rankProfileRegistry, new QueryProfiles(), true); @@ -69,7 +69,7 @@ public class LiteralBoostTestCase extends AbstractExportingTestCase { SDField field1= document.addField("a", DataType.STRING); field1.parseIndexingScript("{ index }"); RankProfile other=new RankProfile("other", search, rankProfileRegistry); - rankProfileRegistry.addRankProfile(other); + rankProfileRegistry.add(other); other.addRankSetting(new RankProfile.RankSetting("a", RankProfile.RankSetting.Type.LITERALBOOST, 333)); search = SearchBuilder.buildFromRawSearch(search, rankProfileRegistry, new QueryProfileRegistry()); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java index 45cdbfa9c1f..9bbc1347aeb 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java @@ -82,14 +82,14 @@ class RankProfileSearchFixture { } public RankProfile compileRankProfile(String rankProfile, Path applicationDir) { - RankProfile compiled = rankProfileRegistry.getRankProfile(search, rankProfile).compile(queryProfileRegistry, new ImportedModels(applicationDir.toFile())); + RankProfile compiled = rankProfileRegistry.get(search, rankProfile).compile(queryProfileRegistry, new ImportedModels(applicationDir.toFile())); compiledRankProfiles.put(rankProfile, compiled); return compiled; } /** Returns the given uncompiled profile */ public RankProfile rankProfile(String rankProfile) { - return rankProfileRegistry.getRankProfile(search, rankProfile); + return rankProfileRegistry.get(search, rankProfile); } /** Returns the given compiled profile, or null if not compiled yet or not present at all */ diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankPropertyVariablesTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankPropertyVariablesTestCase.java index df2bcca63dd..d740884d3e5 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankPropertyVariablesTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankPropertyVariablesTestCase.java @@ -25,12 +25,12 @@ public class RankPropertyVariablesTestCase extends SearchDefinitionTestCase { new BaseDeployLogger(), rankProfileRegistry, new QueryProfileRegistry()); - assertRankPropEquals(rankProfileRegistry.getRankProfile(search, "other").getRankProperties(), "$testvar1", "foo"); - assertRankPropEquals(rankProfileRegistry.getRankProfile(search, "other").getRankProperties(), "$testvar_2", "bar"); - assertRankPropEquals(rankProfileRegistry.getRankProfile(search, "other").getRankProperties(), "$testvarOne23", "baz"); - assertRankPropEquals(rankProfileRegistry.getRankProfile(search, "another").getRankProperties(), "$Testvar1", "1"); - assertRankPropEquals(rankProfileRegistry.getRankProfile(search, "another").getRankProperties(), "$Testvar_4", "4"); - assertRankPropEquals(rankProfileRegistry.getRankProfile(search, "another").getRankProperties(), "$testvarFour23", "234234.234"); + assertRankPropEquals(rankProfileRegistry.get(search, "other").getRankProperties(), "$testvar1", "foo"); + assertRankPropEquals(rankProfileRegistry.get(search, "other").getRankProperties(), "$testvar_2", "bar"); + assertRankPropEquals(rankProfileRegistry.get(search, "other").getRankProperties(), "$testvarOne23", "baz"); + assertRankPropEquals(rankProfileRegistry.get(search, "another").getRankProperties(), "$Testvar1", "1"); + assertRankPropEquals(rankProfileRegistry.get(search, "another").getRankProperties(), "$Testvar_4", "4"); + assertRankPropEquals(rankProfileRegistry.get(search, "another").getRankProperties(), "$testvarFour23", "234234.234"); } private void assertRankPropEquals(List<RankProperty> props, String key, String val) { diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidatorTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidatorTestCase.java index 61baee6da26..d8eb4368b57 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidatorTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidatorTestCase.java @@ -3,12 +3,10 @@ package com.yahoo.searchdefinition.processing; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankProfileRegistry; -import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.tensor.TensorType; import com.yahoo.yolean.Exceptions; -import org.junit.Ignore; import org.junit.Test; import java.util.Map; @@ -137,7 +135,7 @@ public class RankingExpressionTypeValidatorTestCase { )); builder.build(); RankProfile profile = - builder.getRankProfileRegistry().getRankProfile(builder.getSearch(), "my_rank_profile"); + builder.getRankProfileRegistry().get(builder.getSearch(), "my_rank_profile"); assertEquals(TensorType.fromSpec("tensor(x[],y[])"), summaryFeatures(profile).get("macro1(a)").type(profile.typeContext(builder.getQueryProfileRegistry()))); assertEquals(TensorType.fromSpec("tensor(z[10])"), @@ -179,7 +177,7 @@ public class RankingExpressionTypeValidatorTestCase { )); builder.build(); RankProfile profile = - builder.getRankProfileRegistry().getRankProfile(builder.getSearch(), "my_rank_profile"); + builder.getRankProfileRegistry().get(builder.getSearch(), "my_rank_profile"); assertEquals(TensorType.fromSpec("tensor(x[],y[])"), summaryFeatures(profile).get("return_a").type(profile.typeContext(builder.getQueryProfileRegistry()))); assertEquals(TensorType.fromSpec("tensor(z[10])"), diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java index 77d20657f64..4db5f312cae 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java @@ -30,8 +30,8 @@ public class RankingExpressionWithOnnxTestCase { private final Path applicationDir = Path.fromString("src/test/integration/onnx/"); - /** The model name - an artifact of the fact that the model here is not placed in the expected directory (models) */ - private final static String name = "test_integration_onnx_models_mnist_softmax_onnx"; + /** The model name */ + private final static String name = "mnist_softmax_onnx"; private final static String vespaExpression = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(" + name + "_Variable), f(a,b)(a * b)), sum, d2), constant(" + name + "_Variable_1), f(a,b)(a + b))"; @@ -168,9 +168,9 @@ public class RankingExpressionWithOnnxTestCase { } catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + - "onnx('mnist_softmax.onnx','y'): " + - "No expressions named 'y' in model 'mnist_softmax.onnx'. Available expressions: mnist_softmax.onnx.default.add", - Exceptions.toMessageString(expected)); + "onnx('mnist_softmax.onnx','y'): " + + "No expressions named 'y' in model 'mnist_softmax.onnx'. Available expressions: mnist_softmax_onnx.default.add", + Exceptions.toMessageString(expected)); } } @@ -230,7 +230,7 @@ public class RankingExpressionWithOnnxTestCase { search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); assertNull("Constant overridden by macro is not added", - search.search().getRankingConstants().get( name + "_Variable")); + search.search().rankingConstants().get( name + "_Variable")); assertLargeConstant( name + "_Variable_1", search, Optional.of(10L)); // At this point the expression is stored - copy application to another location which do not have a models dir @@ -244,7 +244,7 @@ public class RankingExpressionWithOnnxTestCase { searchFromStored.compileRankProfile("my_profile", applicationDir.append("models")); searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); assertNull("Constant overridden by macro is not added", - searchFromStored.search().getRankingConstants().get( name + "_Variable")); + searchFromStored.search().rankingConstants().get( name + "_Variable")); assertLargeConstant( name + "_Variable_1", searchFromStored, Optional.of(10L)); } finally { IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); @@ -258,7 +258,7 @@ public class RankingExpressionWithOnnxTestCase { private void assertLargeConstant(String name, RankProfileSearchFixture search, Optional<Long> expectedSize) { try { Path constantApplicationPackagePath = Path.fromString("models.generated/mnist_softmax.onnx/constants").append(name + ".tbf"); - RankingConstant rankingConstant = search.search().getRankingConstants().get(name); + RankingConstant rankingConstant = search.search().rankingConstants().get(name); assertEquals(name, rankingConstant.getName()); assertTrue(rankingConstant.getFileName().endsWith(constantApplicationPackagePath.toString())); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index cf37864b73a..a212726efda 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -45,8 +45,8 @@ public class RankingExpressionWithTensorFlowTestCase { private final Path applicationDir = Path.fromString("src/test/integration/tensorflow/"); - /** The model name - an artifact of the fact that the model here is not placed in the expected directory (models) */ - private final String name = "test_integration_tensorflow_models_mnist_softmax_saved"; + /** The model name */ + private final String name = "mnist_softmax_saved"; private final String vespaExpression = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(" + name + "_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(" + name + "_layer_Variable_1_read), f(a,b)(a + b))"; @@ -286,7 +286,7 @@ public class RankingExpressionWithTensorFlowTestCase { search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile_child"); assertNull("Constant overridden by macro is not added", - search.search().getRankingConstants().get("mnist_softmax_saved_layer_Variable_read")); + search.search().rankingConstants().get("mnist_softmax_saved_layer_Variable_read")); assertLargeConstant(name + "_layer_Variable_1_read", search, Optional.of(10L)); // At this point the expression is stored - copy application to another location which do not have a models dir @@ -302,7 +302,7 @@ public class RankingExpressionWithTensorFlowTestCase { searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile_child"); assertNull("Constant overridden by macro is not added", - searchFromStored.search().getRankingConstants().get("mnist_softmax_saved_layer_Variable_read")); + searchFromStored.search().rankingConstants().get("mnist_softmax_saved_layer_Variable_read")); assertLargeConstant(name + "_layer_Variable_1_read", searchFromStored, Optional.of(10L)); } finally { @@ -322,7 +322,7 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testMacroGeneration() { - final String name = "test_integration_tensorflow_models_mnist_saved"; + final String name = "mnist_saved"; final String expression = "join(join(reduce(join(join(join(imported_ml_macro_" + name + "_dnn_hidden2_add, reduce(constant(" + name + "_dnn_hidden2_Const), sum, d2), f(a,b)(a * b)), imported_ml_macro_" + name + "_dnn_hidden2_add, f(a,b)(max(a,b))), constant(" + name + "_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(" + name + "_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))"; final String macroExpression1 = "join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(" + name + "_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(" + name + "_dnn_hidden1_bias_read), f(a,b)(a + b))"; final String macroExpression2 = "join(reduce(join(join(join(imported_ml_macro_" + name + "_dnn_hidden1_add, 0.009999999776482582, f(a,b)(a * b)), imported_ml_macro_" + name + "_dnn_hidden1_add, f(a,b)(max(a,b))), constant(" + name + "_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(" + name + "_dnn_hidden2_bias_read), f(a,b)(a + b))"; @@ -340,7 +340,7 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testImportingFromStoredExpressionsWithSmallConstantsAndInheritance() throws IOException { - final String name = "test_integration_tensorflow_models_mnist_saved"; + final String name = "mnist_saved"; final String rankProfiles = " rank-profile my_profile {\n" + " macro input() {\n" + @@ -404,7 +404,7 @@ public class RankingExpressionWithTensorFlowTestCase { private void assertLargeConstant(String name, RankProfileSearchFixture search, Optional<Long> expectedSize) { try { Path constantApplicationPackagePath = Path.fromString("models.generated/mnist_softmax/saved/constants").append(name + ".tbf"); - RankingConstant rankingConstant = search.search().getRankingConstants().get(name); + RankingConstant rankingConstant = search.search().rankingConstants().get(name); assertEquals(name, rankingConstant.getName()); assertTrue(rankingConstant.getFileName().endsWith(constantApplicationPackagePath.toString())); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java index 31ceb97ab50..86127a260c5 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java @@ -12,7 +12,6 @@ import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import org.junit.Test; import java.io.IOException; -import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -26,7 +25,7 @@ public class RankingExpressionsTestCase extends SearchDefinitionTestCase { Search search = SearchBuilder.createFromDirectory("src/test/examples/rankingexpressionfunction", rankProfileRegistry, new QueryProfileRegistry()).getSearch(); - final RankProfile macrosRankProfile = rankProfileRegistry.getRankProfile(search, "macros"); + final RankProfile macrosRankProfile = rankProfileRegistry.get(search, "macros"); macrosRankProfile.parseExpressions(); final Map<String, RankProfile.Macro> macros = macrosRankProfile.getMacros(); assertEquals(2, macros.get("titlematch$").getFormalParams().size()); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java index 6b287c77a10..76c50821cb9 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java @@ -200,7 +200,7 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase { "}\n"); builder.build(true, new BaseDeployLogger()); Search s = builder.getSearch(); - RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(queryProfiles, new ImportedModels()); + RankProfile test = rankProfileRegistry.get(s, "test").compile(queryProfiles, new ImportedModels()); List<Pair<String, String>> testRankProperties = new RawRankProfile(test, queryProfiles, new ImportedModels(), diff --git a/config-model/src/test/java/com/yahoo/vespa/model/builder/xml/dom/LegacyConfigModelBuilderTest.java b/config-model/src/test/java/com/yahoo/vespa/model/builder/xml/dom/LegacyConfigModelBuilderTest.java index 8070c4ef1bd..76182076ee5 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/builder/xml/dom/LegacyConfigModelBuilderTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/builder/xml/dom/LegacyConfigModelBuilderTest.java @@ -27,7 +27,7 @@ public class LegacyConfigModelBuilderTest { String services = "<foo><config name=\"bar\"><key>value</key></config></foo>"; ModelBuilder builder = new ModelBuilder(); Model model = builder.build(DeployState.createTestState(new MockApplicationPackage.Builder().withServices(services).build()), - null, new MockRoot(), XML.getDocument(services).getDocumentElement()); + null, null, new MockRoot(), XML.getDocument(services).getDocumentElement()); assertThat(model.getContext().getParentProducer().getUserConfigs().size(), is(1)); } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/ContainerClusterTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/ContainerClusterTest.java index d9c151480fe..850fd91e151 100755 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/ContainerClusterTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/ContainerClusterTest.java @@ -16,6 +16,7 @@ import com.yahoo.config.provision.Zone; import com.yahoo.container.handler.ThreadpoolConfig; import com.yahoo.container.jdisc.config.MetricDefaultsConfig; import com.yahoo.search.config.QrStartConfig; +import com.yahoo.searchdefinition.derived.RankProfileList; import com.yahoo.vespa.model.Host; import com.yahoo.vespa.model.HostResource; import com.yahoo.vespa.model.admin.clustercontroller.ClusterControllerClusterVerifier; @@ -80,7 +81,7 @@ public class ContainerClusterTest { .zone(new Zone(SystemName.cd, Environment.test, RegionName.from("some-region"))) .build(); MockRoot root = new MockRoot("foo", state); - ContainerCluster cluster = new ContainerCluster(root, "container0", "container1"); + ContainerCluster cluster = new ContainerCluster(root, "container0", "container1", RankProfileList.empty); ConfigserverConfig.Builder builder = new ConfigserverConfig.Builder(); cluster.getConfig(builder); ConfigserverConfig config = new ConfigserverConfig(builder); @@ -111,8 +112,8 @@ public class ContainerClusterTest { MockRoot root = new MockRoot("foo", state); ContainerCluster cluster = extraComponents.isPresent() - ? new ContainerCluster(root, "container0", "container1", extraComponents.get()) - : new ContainerCluster(root, "container0", "container1"); + ? new ContainerCluster(root, "container0", "container1", extraComponents.get(), RankProfileList.empty) + : new ContainerCluster(root, "container0", "container1", RankProfileList.empty); if (isCombinedCluster) cluster.setHostClusterId("test-content-cluster"); cluster.setMemoryPercentage(memoryPercentage); @@ -257,7 +258,7 @@ public class ContainerClusterTest { public void requireThatRoutingProviderIsDisabledForNonHosted() { DeployState state = new DeployState.Builder().properties(new DeployProperties.Builder().hostedVespa(false).build()).build(); MockRoot root = new MockRoot("foo", state); - ContainerCluster cluster = new ContainerCluster(root, "container0", "container1"); + ContainerCluster cluster = new ContainerCluster(root, "container0", "container1", RankProfileList.empty); RoutingProviderConfig.Builder builder = new RoutingProviderConfig.Builder(); cluster.getConfig(builder); RoutingProviderConfig config = new RoutingProviderConfig(builder); @@ -281,7 +282,7 @@ public class ContainerClusterTest { } private static ContainerCluster newContainerCluster() { - ContainerCluster cluster = new ContainerCluster(null, "subId", "name"); + ContainerCluster cluster = new ContainerCluster(null, "subId", "name", RankProfileList.empty); addContainer(cluster, "c1", "host-c1"); addContainer(cluster, "c2", "host-c2"); return cluster; diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/configserver/ConfigserverClusterTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/configserver/ConfigserverClusterTest.java index d4209c9c788..b4ad2ddbd21 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/configserver/ConfigserverClusterTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/configserver/ConfigserverClusterTest.java @@ -38,7 +38,7 @@ public class ConfigserverClusterTest { new ConfigServerContainerModelBuilder(new TestOptions().rpcPort(12345).useVespaVersionInRequest(true) .hostedVespa(true).environment("test").region("bar") .numParallelTenantLoaders(99)) - .build(new DeployState.Builder().build(), null, root, XML.getDocument(services).getDocumentElement()); + .build(new DeployState.Builder().build(), null, null, root, XML.getDocument(services).getDocumentElement()); root.freezeModelTopology(); } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/docproc/StandaloneDocprocContainerTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/docproc/StandaloneDocprocContainerTest.java index 8995d6b80b0..3edc70833d8 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/docproc/StandaloneDocprocContainerTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/docproc/StandaloneDocprocContainerTest.java @@ -18,14 +18,13 @@ import static org.hamcrest.CoreMatchers.is; import static org.junit.Assert.assertThat; /** - * @author <a href="mailto:einarmr@yahoo-inc.com">Einar M R Rosenvinge</a> - * @since 5.1.14 + * @author Einar M R Rosenvinge */ public class StandaloneDocprocContainerTest extends DomBuilderTest { public ContainerCluster setupCluster(boolean standalone) { ContainerModelBuilder builder = new ContainerModelBuilder(standalone, Networking.disable); - ContainerModel model = builder.build(DeployState.createTestState(), null, root, servicesXml()); + ContainerModel model = builder.build(DeployState.createTestState(), null, null, root, servicesXml()); if (!standalone) model.getCluster().getDocproc().getChains().addServersAndClientsForChains(); diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/http/FilterBindingsTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/http/FilterBindingsTest.java index 420115627dc..9d5508cca75 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/http/FilterBindingsTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/http/FilterBindingsTest.java @@ -38,7 +38,7 @@ public class FilterBindingsTest extends DomBuilderTest { private void buildContainerCluster(Element containerElem) throws SAXException, IOException { - ContainerModel model = new ContainerModelBuilder(true, Networking.enable).build(DeployState.createTestState(), null, root, containerElem); + ContainerModel model = new ContainerModelBuilder(true, Networking.enable).build(DeployState.createTestState(), null, null, root, containerElem); root.freezeModelTopology(); } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilderTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilderTest.java index e3dfa093735..f94ebab42a9 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilderTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilderTest.java @@ -513,7 +513,7 @@ public class ContainerModelBuilderTest extends ContainerModelBuilderTestBase { "</jdisc>"); DeployState deployState = new DeployState.Builder().zone(new Zone(Environment.dev, RegionName.from("us-east-1"))).build(); - createModel(root, deployState, clusterElem); + createModel(root, deployState, null, clusterElem); assertEquals(0, getContainerCluster("default").serviceAliases().size()); assertEquals(0, getContainerCluster("default").endpointAliases().size()); diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilderTestBase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilderTestBase.java index 47f8a1bbe29..e46e736dcd6 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilderTestBase.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilderTestBase.java @@ -5,6 +5,7 @@ import com.yahoo.component.ComponentId; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.config.model.test.MockRoot; import com.yahoo.container.ComponentsConfig; +import com.yahoo.vespa.model.VespaModel; import com.yahoo.vespa.model.container.ContainerCluster; import com.yahoo.vespa.model.container.ContainerModel; import com.yahoo.vespa.model.container.component.Component; @@ -22,7 +23,6 @@ import java.util.Collections; * not be done when using this class * * @author gjoranv - * @since 5.5 */ public abstract class ContainerModelBuilderTestBase { @@ -32,17 +32,18 @@ public abstract class ContainerModelBuilderTestBase { " </nodes>"; protected MockRoot root; - public static void createModel(MockRoot root, DeployState deployState, Element... containerElems) throws SAXException, IOException { + public static void createModel(MockRoot root, DeployState deployState, VespaModel vespaModel, Element... containerElems) { for (Element containerElem : containerElems) { - ContainerModel model = new ContainerModelBuilder(false, ContainerModelBuilder.Networking.enable).build(deployState, null, root, containerElem); + ContainerModel model = new ContainerModelBuilder(false, ContainerModelBuilder.Networking.enable) + .build(deployState, vespaModel, null, root, containerElem); ContainerCluster cluster = model.getCluster(); generateDefaultSearchChains(cluster); } root.freezeModelTopology(); } - public static void createModel(MockRoot root, Element... containerElems) throws SAXException, IOException { - createModel(root, DeployState.createTestState(), containerElems); + public static void createModel(MockRoot root, Element... containerElems) { + createModel(root, DeployState.createTestState(), null, containerElems); } private static void generateDefaultSearchChains(ContainerCluster cluster) { diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/DocprocBuilderTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/DocprocBuilderTest.java index 31191d0c5fb..f4d3fbc782c 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/DocprocBuilderTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/DocprocBuilderTest.java @@ -49,7 +49,7 @@ public class DocprocBuilderTest extends DomBuilderTest { @Before public void setupCluster() { - ContainerModel model = new ContainerModelBuilder(false, Networking.disable).build(DeployState.createTestState(), null, root, servicesXml()); + ContainerModel model = new ContainerModelBuilder(false, Networking.disable).build(DeployState.createTestState(), null, null, root, servicesXml()); cluster = model.getCluster(); cluster.getDocproc().getChains().addServersAndClientsForChains(); root.freezeModelTopology(); diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/IdentityBuilderTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/IdentityBuilderTest.java index d3ad2ccc721..0fd138a7943 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/IdentityBuilderTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/IdentityBuilderTest.java @@ -34,7 +34,7 @@ public class IdentityBuilderTest extends ContainerModelBuilderTestBase { .withDeploymentSpec(deploymentXml) .build(); - createModel(root, DeployState.createTestState(applicationPackage), clusterElem); + createModel(root, DeployState.createTestState(applicationPackage), null, clusterElem); IdentityConfig identityConfig = root.getConfig(IdentityConfig.class, "default/component/" + IdentityProvider.CLASS); assertEquals("domain", identityConfig.domain()); diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/RoutingBuilderTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/RoutingBuilderTest.java index a2f32694340..3d61ec3a3af 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/RoutingBuilderTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/RoutingBuilderTest.java @@ -70,7 +70,7 @@ public class RoutingBuilderTest extends ContainerModelBuilderTestBase { .build(); root = new MockRoot("root", deployState); - createModel(root, deployState, clusterElem); + createModel(root, deployState, null, clusterElem); ContainerCluster cluster = getContainerCluster("default"); return cluster.getContainers().get(0); diff --git a/config-model/src/test/java/com/yahoo/vespa/model/content/utils/ContentClusterUtils.java b/config-model/src/test/java/com/yahoo/vespa/model/content/utils/ContentClusterUtils.java index b0d6c94947a..89d81ee262b 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/content/utils/ContentClusterUtils.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/content/utils/ContentClusterUtils.java @@ -61,8 +61,8 @@ public class ContentClusterUtils { Document doc = XML.getDocument(clusterXml); Admin admin = new Admin(root, new DefaultMonitoring("vespa", 60), new Metrics(), Collections.emptyMap(), false, new FileDistributionConfigProducer(root, new MockFileRegistry(), null)); - ConfigModelContext context = ConfigModelContext.create(null, root.getDeployState(), null, root, null); - + ConfigModelContext context = ConfigModelContext.create(null, root.getDeployState(), null,null, root, null); + return new ContentCluster.Builder(admin).build(Collections.emptyList(), context, doc.getDocumentElement()); } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/generic/GenericServicesModelTest.java b/config-model/src/test/java/com/yahoo/vespa/model/generic/GenericServicesModelTest.java index 21521ba6b7e..91e9e5e656f 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/generic/GenericServicesModelTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/generic/GenericServicesModelTest.java @@ -32,7 +32,7 @@ public class GenericServicesModelTest { @Test public void test_generic_services_model() { MockRoot root = new MockRoot(); - GenericServicesModel model = new GenericServicesModel(ConfigModelContext.create(null, root, "foo")); + GenericServicesModel model = new GenericServicesModel(ConfigModelContext.create(null, null, root, "foo")); assertThat(model.serviceClusters().size(), is(0)); model.addCluster(new ServiceCluster(root, "mycluster", "/bin/foo")); assertThat(model.serviceClusters().size(), is(1)); diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileRegistry.java b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileRegistry.java index 8acd1ec4248..0363b50815b 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileRegistry.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileRegistry.java @@ -16,9 +16,6 @@ public class QueryProfileRegistry extends ComponentRegistry<QueryProfile> { private QueryProfileTypeRegistry queryProfileTypeRegistry = new QueryProfileTypeRegistry(); - /** The current default instance of this registry */ - private static QueryProfileRegistry instance = new QueryProfileRegistry(); - /** Register this type by its id */ public void register(QueryProfile profile) { super.register(profile.getId(), profile); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java index ec7bdcf5f2b..045844ee219 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java @@ -1,5 +1,6 @@ package com.yahoo.searchlib.rankingexpression.integration.ml; +import com.yahoo.collections.Pair; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -97,6 +98,37 @@ public class ImportedModel { void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); } /** + * Returns all the outputs of this by name. The names consist of one to three parts + * separated by dot, where the first part is the model name, the second is the signature name + * if signatures are used, or the expression name if signatures are not used and there are multiple + * expressions, and the third is the output name if signature names are used. + */ + public List<Pair<String, RankingExpression>> outputExpressions() { + List<Pair<String, RankingExpression>> names = new ArrayList<>(); + for (Map.Entry<String, Signature> signatureEntry : signatures().entrySet()) { + for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet()) + names.add(new Pair<>(name + "." + signatureEntry.getKey() + "." + outputEntry.getKey(), + expressions().get(outputEntry.getValue()))); + if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs + names.add(new Pair<>(name + "." + signatureEntry.getKey(), + expressions().get(signatureEntry.getKey()))); + } + if (signatures().isEmpty()) { // fallback for models without signatures + if (expressions().size() == 1) {// Use just model name + names.add(new Pair<>(name, + expressions().values().iterator().next())); + } + else { + for (Map.Entry<String, RankingExpression> expressionEntry : expressions().entrySet()) { + names.add(new Pair<>(name + "." + expressionEntry.getKey(), + expressionEntry.getValue())); + } + } + } + return names; + } + + /** * A signature is a set of named inputs and outputs, where the inputs maps to argument * ("placeholder") names+types, and outputs maps to expressions nodes. * Note that TensorFlow supports multiple signatures in their format, but ONNX has no explicit diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java index 3fa6141a696..827b1911369 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java @@ -6,6 +6,9 @@ import com.google.common.collect.ImmutableMap; import com.yahoo.path.Path; import java.io.File; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; import java.util.Optional; /** @@ -63,9 +66,24 @@ public class ImportedModels { return importedModels.get(toName(modelPath)); } + /** Returns an immutable collection of all the imported models */ + public Collection<ImportedModel> all() { + return importedModels.values(); + } + private static String toName(File modelPath) { - Path localPath = Path.fromString(modelPath.toString()).getChildPath(); - return localPath.toString().replace("/", "_").replace('.', '_'); + String localPath = concatenateAfterModelsDirectory(Path.fromString(modelPath.toString())); + return localPath.replace('.', '_'); + } + + private static String concatenateAfterModelsDirectory(Path path) { + boolean afterModels = false; + StringBuilder result = new StringBuilder(); + for (String element : path.elements()) { + if (afterModels) result.append(element).append("_"); + if (element.equals("models")) afterModels = true; + } + return result.substring(0, result.length()-1); } } |