diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-09-25 15:49:22 -0700 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-09-25 15:49:22 -0700 |
commit | 11884899e39c54abeb79bacbe723df0ff34ce869 (patch) | |
tree | 674025004f825c9cc12a075f992c0b2d1d45509e | |
parent | 0246064bbfb9657515f516e2fea12d593cd13016 (diff) |
Revert "Merge pull request #7094 from vespa-engine/revert-7070-bratseth/rank-type-information-2"
This reverts commit 0246064bbfb9657515f516e2fea12d593cd13016, reversing
changes made to f627463a8100090ec109d27c3aeb439a3395a34f.
35 files changed, 452 insertions, 241 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index 16e494c2db1..937151c0d3a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -449,7 +449,7 @@ public class RankProfile implements Serializable, Cloneable { addRankProperty(new RankProperty(name, parameter)); } - public void addRankProperty(RankProperty rankProperty) { + private void addRankProperty(RankProperty rankProperty) { // Just the usual multimap semantics here List<RankProperty> properties = rankProperties.get(rankProperty.getName()); if (properties == null) { @@ -541,15 +541,14 @@ public class RankProfile implements Serializable, Cloneable { /** Returns an unmodifiable view of the functions in this */ public Map<String, RankingExpressionFunction> getFunctions() { - if (functions.size() == 0 && getInherited() == null) return Collections.emptyMap(); - if (functions.size() == 0) return getInherited().getFunctions(); + if (functions.isEmpty() && getInherited() == null) return Collections.emptyMap(); + if (functions.isEmpty()) return getInherited().getFunctions(); if (getInherited() == null) return Collections.unmodifiableMap(functions); // Neither is null Map<String, RankingExpressionFunction> allFunctions = new LinkedHashMap<>(getInherited().getFunctions()); allFunctions.putAll(functions); return Collections.unmodifiableMap(allFunctions); - } public int getKeepRankCount() { @@ -695,7 +694,7 @@ public class RankProfile implements Serializable, Cloneable { for (Map.Entry<String, RankingExpressionFunction> entry : functions.entrySet()) { RankingExpressionFunction rankingExpressionFunction = entry.getValue(); RankingExpression compiled = compile(rankingExpressionFunction.function().getBody(), queryProfiles, importedModels, getConstants(), inlineFunctions, expressionTransforms); - compiledFunctions.put(entry.getKey(), rankingExpressionFunction.withBody(compiled)); + compiledFunctions.put(entry.getKey(), rankingExpressionFunction.withExpression(compiled)); } return compiledFunctions; } @@ -898,7 +897,7 @@ public class RankProfile implements Serializable, Cloneable { /** A function in a rank profile */ public static class RankingExpressionFunction { - private final ExpressionFunction function; + private ExpressionFunction function; /** True if this should be inlined into calling expressions. Useful for very cheap functions. */ private final boolean inline; @@ -908,13 +907,17 @@ public class RankProfile implements Serializable, Cloneable { this.inline = inline; } + public void setReturnType(TensorType type) { + this.function = function.withReturnType(type); + } + public ExpressionFunction function() { return function; } public boolean inline() { return inline && function.arguments().isEmpty(); // only inline no-arg functions; } - public RankingExpressionFunction withBody(RankingExpression expression) { + public RankingExpressionFunction withExpression(RankingExpression expression) { return new RankingExpressionFunction(function.withBody(expression), inline); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/Search.java b/config-model/src/main/java/com/yahoo/searchdefinition/Search.java index f42d5de21e8..a988da9664e 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/Search.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/Search.java @@ -59,9 +59,6 @@ public class Search implements Serializable, ImmutableSearch { // Field sets private FieldSets fieldSets = new FieldSets(); - // Whether or not this object has been processed. - private boolean processed; - // The unique name of this search definition. private String name; @@ -585,17 +582,6 @@ public class Search implements Serializable, ImmutableSearch { return false; } - public void process() { - if (processed) { - throw new IllegalStateException("Search '" + getName() + "' already processed."); - } - processed = true; - } - - public boolean isProcessed() { - return processed; - } - /** * The field set settings for this search * diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java b/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java index 3c2ebc058ac..151ad02a3fa 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java @@ -188,10 +188,6 @@ public class SearchBuilder { throw new IllegalArgumentException("Search has no name."); } String rawName = rawSearch.getName(); - if (rawSearch.isProcessed()) { - throw new IllegalArgumentException("A search definition with a search section called '" + rawName + - "' has already been processed."); - } for (Search search : searchList) { if (rawName.equals(search.getName())) { throw new IllegalArgumentException("A search definition with a search section called '" + rawName + @@ -247,8 +243,7 @@ public class SearchBuilder { DocumentModelBuilder builder = new DocumentModelBuilder(model); for (Search search : new SearchOrderer().order(searchList)) { - new FieldOperationApplierForSearch().process(search); - // These two needed for a couple of old unit tests, ideally these are just read from app + new FieldOperationApplierForSearch().process(search); // TODO: Why is this not in the regular list? process(search, deployLogger, new QueryProfiles(queryProfileRegistry), validate); built.add(search); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java index 9a00ee5bbd0..7c2d9a3b0ad 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java @@ -74,21 +74,11 @@ public class DerivedConfiguration { QueryProfileRegistry queryProfiles, ImportedModels importedModels) { Validator.ensureNotNull("Search definition", search); - if ( ! search.isProcessed()) { - throw new IllegalArgumentException("Search '" + search.getName() + "' not processed."); - } this.search = search; if ( ! search.isDocumentsOnly()) { streamingFields = new VsmFields(search); streamingSummary = new VsmSummary(search); } - if (abstractSearchList != null) { - for (Search abstractSearch : abstractSearchList) { - if (!abstractSearch.isProcessed()) { - throw new IllegalArgumentException("Search '" + search.getName() + "' not processed."); - } - } - } if ( ! search.isDocumentsOnly()) { attributeFields = new AttributeFields(search); summaries = new Summaries(search, deployLogger); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java index c041d5c6a89..279b5334187 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java @@ -13,6 +13,7 @@ import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.SerializationContext; +import com.yahoo.tensor.TensorType; import com.yahoo.vespa.config.search.RankProfilesConfig; import java.nio.charset.Charset; @@ -23,6 +24,7 @@ import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; /** * A rank profile derived from a search definition, containing exactly the features available natively in the server @@ -180,32 +182,39 @@ public class RawRankProfile implements RankProfilesConfig.Producer { private void derivePropertiesAndSummaryFeaturesFromFunctions(Map<String, RankProfile.RankingExpressionFunction> functions) { if (functions.isEmpty()) return; - Map<String, ExpressionFunction> expressionFunctions = new LinkedHashMap<>(); - for (Map.Entry<String, RankProfile.RankingExpressionFunction> function : functions.entrySet()) { - expressionFunctions.put(function.getKey(), function.getValue().function()); - } + + List<ExpressionFunction> functionExpressions = functions.values().stream().map(f -> f.function()).collect(Collectors.toList()); Map<String, String> functionProperties = new LinkedHashMap<>(); - functionProperties.putAll(deriveFunctionProperties(expressionFunctions)); + functionProperties.putAll(deriveFunctionProperties(functions, functionExpressions)); + if (firstPhaseRanking != null) { - functionProperties.putAll(firstPhaseRanking.getRankProperties(new ArrayList<>(expressionFunctions.values()))); + functionProperties.putAll(firstPhaseRanking.getRankProperties(functionExpressions)); } if (secondPhaseRanking != null) { - functionProperties.putAll(secondPhaseRanking.getRankProperties(new ArrayList<>(expressionFunctions.values()))); + functionProperties.putAll(secondPhaseRanking.getRankProperties(functionExpressions)); } for (Map.Entry<String, String> e : functionProperties.entrySet()) { rankProperties.add(new RankProfile.RankProperty(e.getKey(), e.getValue())); } - SerializationContext context = new SerializationContext(expressionFunctions.values(), null, functionProperties); + SerializationContext context = new SerializationContext(functionExpressions, null, functionProperties); replaceFunctionSummaryFeatures(context); } - private Map<String, String> deriveFunctionProperties(Map<String, ExpressionFunction> functions) { - SerializationContext context = new SerializationContext(functions); - for (Map.Entry<String, ExpressionFunction> e : functions.entrySet()) { - String expression = e.getValue().getBody().getRoot().toString(new StringBuilder(), context, null, null).toString(); - context.addFunctionSerialization(RankingExpression.propertyName(e.getKey()), expression); + private Map<String, String> deriveFunctionProperties(Map<String, RankProfile.RankingExpressionFunction> functions, + List<ExpressionFunction> functionExpressions) { + SerializationContext context = new SerializationContext(functionExpressions); + for (Map.Entry<String, RankProfile.RankingExpressionFunction> e : functions.entrySet()) { + String expressionString = e.getValue().function().getBody().getRoot().toString(new StringBuilder(), context, null, null).toString(); + context.addFunctionSerialization(RankingExpression.propertyName(e.getKey()), expressionString); + + for (Map.Entry<String, TensorType> argumentType : e.getValue().function().argumentTypes().entrySet()) + context.addArgumentTypeSerialization(e.getKey(), argumentType.getKey(), argumentType.getValue()); + if (e.getValue().function().returnType().isPresent()) + context.addFunctionTypeSerialization(e.getKey(), e.getValue().function().returnType().get()); + else if (e.getValue().function().arguments().isEmpty()) + throw new IllegalStateException("Type of function '" + e.getKey() + "' is not resolved"); } return context.serializedFunctions(); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java index 8c8c32389e2..15d295736c1 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java @@ -7,10 +7,8 @@ import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.processing.multifieldresolver.RankProfileTypeSettingsProcessor; import com.yahoo.vespa.model.container.search.QueryProfiles; -import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; -import java.util.List; /** * Executor of processors. This defines the right order of processor execution. @@ -75,12 +73,20 @@ public class Processing { ReferenceFieldsProcessor::new, FastAccessValidator::new, ReservedFunctionNames::new, - RankingExpressionTypeValidator::new, + RankingExpressionTypeResolver::new, // These should be last: IndexingValidation::new, IndexingValues::new); } + /** Processors of rank profiles only (those who tolerate and so something useful when the search field is null) */ + private Collection<ProcessorFactory> rankProfileProcessors() { + return Arrays.asList( + RankProfileTypeSettingsProcessor::new, + ReservedFunctionNames::new, + RankingExpressionTypeResolver::new); + } + /** * Runs all search processors on the given {@link Search} object. These will modify the search object, <b>possibly * exchanging it with another</b>, as well as its document types. @@ -93,12 +99,26 @@ public class Processing { public void process(Search search, DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry, QueryProfiles queryProfiles, boolean validate, boolean documentsOnly) { Collection<ProcessorFactory> factories = processors(); - search.process(); factories.stream() .map(factory -> factory.create(search, deployLogger, rankProfileRegistry, queryProfiles)) .forEach(processor -> processor.process(validate, documentsOnly)); } + /** + * Runs rank profiles processors only. + * + * @param deployLogger The log to log messages and warnings for application deployment to + * @param rankProfileRegistry a {@link com.yahoo.searchdefinition.RankProfileRegistry} + * @param queryProfiles The query profiles contained in the application this search is part of. + */ + public void processRankProfiles(DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry, + QueryProfiles queryProfiles, boolean validate, boolean documentsOnly) { + Collection<ProcessorFactory> factories = rankProfileProcessors(); + factories.stream() + .map(factory -> factory.create(null, deployLogger, rankProfileRegistry, queryProfiles)) + .forEach(processor -> processor.process(validate, documentsOnly)); + } + @FunctionalInterface public interface ProcessorFactory { Processor create(Search search, DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry, QueryProfiles queryProfiles); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidator.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java index 102d1910360..4c8b5910b78 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidator.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java @@ -7,39 +7,44 @@ import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.Search; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.TypeContext; import com.yahoo.vespa.model.container.search.QueryProfiles; +import java.util.Map; + /** - * Validates the types of all ranking expressions under a search instance: + * Resolves and assigns types to all functions in a ranking expression, and + * validates the types of all ranking expressions under a search instance: * Some operators constrain the types of inputs, and first-and second-phase expressions - * must return scalar values. In addition, the existence of all referred attribute, query and constant + * must return scalar values. + * + * In addition, the existence of all referred attribute, query and constant * features is ensured. * * @author bratseth */ -public class RankingExpressionTypeValidator extends Processor { +public class RankingExpressionTypeResolver extends Processor { private final QueryProfileRegistry queryProfiles; - public RankingExpressionTypeValidator(Search search, - DeployLogger deployLogger, - RankProfileRegistry rankProfileRegistry, - QueryProfiles queryProfiles) { + public RankingExpressionTypeResolver(Search search, + DeployLogger deployLogger, + RankProfileRegistry rankProfileRegistry, + QueryProfiles queryProfiles) { super(search, deployLogger, rankProfileRegistry, queryProfiles); this.queryProfiles = queryProfiles.getRegistry(); } @Override public void process(boolean validate, boolean documentsOnly) { - if ( ! validate) return; if (documentsOnly) return; for (RankProfile profile : rankProfileRegistry.rankProfilesOf(search)) { try { - validate(profile); + resolveTypesIn(profile, validate); } catch (IllegalArgumentException e) { throw new IllegalArgumentException("In " + search + ", " + profile, e); @@ -47,20 +52,34 @@ public class RankingExpressionTypeValidator extends Processor { } } - /** Throws an IllegalArgumentException if the given rank profile does not produce valid type */ - private void validate(RankProfile profile) { - TypeContext context = profile.typeContext(queryProfiles); - profile.getSummaryFeatures().forEach(f -> ensureValid(f, "summary feature " + f, context)); - ensureValidDouble(profile.getFirstPhaseRanking(), "first-phase expression", context); - ensureValidDouble(profile.getSecondPhaseRanking(), "second-phase expression", context); + /** + * Resolves the types of all functions in the given profile + * + * @throws IllegalArgumentException if validate is true and the given rank profile does not produce valid types + */ + private void resolveTypesIn(RankProfile profile, boolean validate) { + TypeContext<Reference> context = profile.typeContext(queryProfiles); + for (Map.Entry<String, RankProfile.RankingExpressionFunction> function : profile.getFunctions().entrySet()) { + if ( ! function.getValue().function().arguments().isEmpty()) continue; + TensorType type = resolveType(function.getValue().function().getBody(), + "function '" + function.getKey() + "'", + context); + function.getValue().setReturnType(type); + } + + if (validate) { + profile.getSummaryFeatures().forEach(f -> resolveType(f, "summary feature " + f, context)); + ensureValidDouble(profile.getFirstPhaseRanking(), "first-phase expression", context); + ensureValidDouble(profile.getSecondPhaseRanking(), "second-phase expression", context); + } } - private TensorType ensureValid(RankingExpression expression, String expressionDescription, TypeContext context) { + private TensorType resolveType(RankingExpression expression, String expressionDescription, TypeContext context) { if (expression == null) return null; - return ensureValid(expression.getRoot(), expressionDescription, context); + return resolveType(expression.getRoot(), expressionDescription, context); } - private TensorType ensureValid(ExpressionNode expression, String expressionDescription, TypeContext context) { + private TensorType resolveType(ExpressionNode expression, String expressionDescription, TypeContext context) { TensorType type; try { type = expression.type(context); @@ -75,7 +94,7 @@ public class RankingExpressionTypeValidator extends Processor { private void ensureValidDouble(RankingExpression expression, String expressionDescription, TypeContext context) { if (expression == null) return; - TensorType type = ensureValid(expression, expressionDescription, context); + TensorType type = resolveType(expression, expressionDescription, context); if ( ! type.equals(TensorType.empty)) throw new IllegalArgumentException("The " + expressionDescription + " must produce a double " + "(a tensor with no dimensions), but produces " + type); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/multifieldresolver/RankProfileTypeSettingsProcessor.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/multifieldresolver/RankProfileTypeSettingsProcessor.java index ec4cbdfe58b..3bde76c1c79 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/multifieldresolver/RankProfileTypeSettingsProcessor.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/multifieldresolver/RankProfileTypeSettingsProcessor.java @@ -44,6 +44,7 @@ public class RankProfileTypeSettingsProcessor extends Processor { } private void processAttributeFields() { + if (search == null) return; // we're processing global profiles for (SDField field : search.allConcreteFields()) { Attribute attribute = field.getAttributes().get(field.getName()); if (attribute != null && attribute.tensorType().isPresent()) { @@ -53,6 +54,7 @@ public class RankProfileTypeSettingsProcessor extends Processor { } private void processImportedFields() { + if (search == null) return; // we're processing global profiles Optional<ImportedFields> importedFields = search.importedFields(); if (importedFields.isPresent()) { importedFields.get().fields().forEach((fieldName, field) -> processImportedField(field)); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/Service.java b/config-model/src/main/java/com/yahoo/vespa/model/Service.java index 620e44bc11a..29ec26b06d2 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/Service.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/Service.java @@ -7,7 +7,7 @@ import java.util.HashMap; import java.util.Optional; /** - * Representation of a process which runs a service + * Representation of a markProcessed which runs a service * * @author gjoranv */ diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java index 4b70b1b5ae2..13304ea10ee 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java @@ -32,7 +32,9 @@ import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.RankingConstants; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RankProfileList; +import com.yahoo.searchdefinition.processing.Processing; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; +import com.yahoo.vespa.model.container.search.QueryProfiles; import com.yahoo.vespa.model.ml.ConvertedModel; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; @@ -168,7 +170,7 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri createGlobalRankProfiles(deployState.getImportedModels(), deployState.rankProfileRegistry(), - deployState.getQueryProfiles().getRegistry()); + deployState.getQueryProfiles()); this.rankProfileList = new RankProfileList(null, // null search -> global rankingConstants, AttributeFields.empty, @@ -219,26 +221,23 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri /** Adds generic application specific clusters of services */ private void addServiceClusters(ApplicationPackage app, VespaModelBuilder builder) { - for (ServiceCluster sc : builder.getClusters(app, this)) - serviceClusters.add(sc); + serviceClusters.addAll(builder.getClusters(app, this)); } /** - * Creates a rank profile not attached to any search definition, for each imported model in the application package + * Creates a rank profile not attached to any search definition, for each imported model in the application package, + * and adds it to the given rank profile registry. */ - private ImmutableList<RankProfile> createGlobalRankProfiles(ImportedModels importedModels, - RankProfileRegistry rankProfileRegistry, - QueryProfileRegistry queryProfiles) { - List<RankProfile> profiles = new ArrayList<>(); + private void createGlobalRankProfiles(ImportedModels importedModels, + RankProfileRegistry rankProfileRegistry, + QueryProfiles queryProfiles) { if ( ! importedModels.all().isEmpty()) { // models/ directory is available for (ImportedModel model : importedModels.all()) { RankProfile profile = new RankProfile(model.name(), this, rankProfileRegistry); rankProfileRegistry.add(profile); ConvertedModel convertedModel = ConvertedModel.fromSource(new ModelName(model.name()), - model.name(), profile, queryProfiles, model); - for (Map.Entry<String, RankingExpression> entry : convertedModel.expressions().entrySet()) { - profile.addFunction(new ExpressionFunction(entry.getKey(), entry.getValue()), false); - } + model.name(), profile, queryProfiles.getRegistry(), model); + convertedModel.expressions().values().forEach(f -> profile.addFunction(f, false)); } } else { // generated and stored model information may be available instead @@ -248,12 +247,12 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri RankProfile profile = new RankProfile(modelName, this, rankProfileRegistry); rankProfileRegistry.add(profile); ConvertedModel convertedModel = ConvertedModel.fromStore(new ModelName(modelName), modelName, profile); - for (Map.Entry<String, RankingExpression> entry : convertedModel.expressions().entrySet()) { - profile.addFunction(new ExpressionFunction(entry.getKey(), entry.getValue()), false); - } + convertedModel.expressions().values().forEach(f -> profile.addFunction(f, false)); } } - return ImmutableList.copyOf(profiles); + new Processing().processRankProfiles(deployState.getDeployLogger(), + rankProfileRegistry, + queryProfiles, true, false); } /** Returns the global rank profiles as a rank profile list */ diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index adf5c81283e..fb0109ed32e 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -48,6 +48,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -67,14 +68,14 @@ public class ConvertedModel { private final ModelName modelName; private final String modelDescription; - private final ImmutableMap<String, RankingExpression> expressions; + private final ImmutableMap<String, ExpressionFunction> expressions; /** The source importedModel, or empty if this was created from a stored converted model */ private final Optional<ImportedModel> sourceModel; private ConvertedModel(ModelName modelName, String modelDescription, - Map<String, RankingExpression> expressions, + Map<String, ExpressionFunction> expressions, Optional<ImportedModel> sourceModel) { this.modelName = modelName; this.modelDescription = modelDescription; @@ -132,23 +133,23 @@ public class ConvertedModel { * if signatures are used, or the expression name if signatures are not used and there are multiple * expressions, and the second is the output name if signature names are used. */ - public Map<String, RankingExpression> expressions() { return expressions; } + public Map<String, ExpressionFunction> expressions() { return expressions; } /** * Returns the expression matching the given arguments. */ public ExpressionNode expression(FeatureArguments arguments, RankProfileTransformContext context) { - RankingExpression expression = selectExpression(arguments); - if (sourceModel.isPresent()) // we can verify - verifyRequiredFunctions(expression, sourceModel.get(), context.rankProfile(), context.queryProfiles()); - return expression.getRoot(); + ExpressionFunction expression = selectExpression(arguments); + if (sourceModel.isPresent()) // we should verify + verifyInputs(expression.getBody(), sourceModel.get(), context.rankProfile(), context.queryProfiles()); + return expression.getBody().getRoot(); } - private RankingExpression selectExpression(FeatureArguments arguments) { + private ExpressionFunction selectExpression(FeatureArguments arguments) { if (expressions.isEmpty()) throw new IllegalArgumentException("No expressions available in " + this); - RankingExpression expression = expressions.get(arguments.toName()); + ExpressionFunction expression = expressions.get(arguments.toName()); if (expression != null) return expression; if ( ! arguments.signature().isPresent()) { @@ -158,7 +159,7 @@ public class ConvertedModel { } if ( ! arguments.output().isPresent()) { - List<Map.Entry<String, RankingExpression>> entriesWithTheRightPrefix = + List<Map.Entry<String, ExpressionFunction>> entriesWithTheRightPrefix = expressions.entrySet().stream().filter(entry -> entry.getKey().startsWith(arguments.signature().get() + ".")).collect(Collectors.toList()); if (entriesWithTheRightPrefix.size() < 1) throw new IllegalArgumentException("No expressions named '" + arguments.signature().get() + @@ -179,10 +180,10 @@ public class ConvertedModel { // ----------------------- Static model conversion/storage below here - private static Map<String, RankingExpression> convertAndStore(ImportedModel model, - RankProfile profile, - QueryProfileRegistry queryProfiles, - ModelStore store) { + private static Map<String, ExpressionFunction> convertAndStore(ImportedModel model, + RankProfile profile, + QueryProfileRegistry queryProfiles, + ModelStore store) { // Add constants Set<String> constantsReplacedByFunctions = new HashSet<>(); model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v)); @@ -193,8 +194,8 @@ public class ConvertedModel { addGeneratedFunctions(model, profile); // Add expressions - Map<String, RankingExpression> expressions = new HashMap<>(); - for (Pair<String, RankingExpression> output : model.outputExpressions()) { + Map<String, ExpressionFunction> expressions = new HashMap<>(); + for (Pair<String, ExpressionFunction> output : model.outputExpressions()) { addExpression(output.getSecond(), output.getFirst(), constantsReplacedByFunctions, model, store, profile, queryProfiles, @@ -210,21 +211,21 @@ public class ConvertedModel { return expressions; } - private static void addExpression(RankingExpression expression, + private static void addExpression(ExpressionFunction expression, String expressionName, Set<String> constantsReplacedByFunctions, ImportedModel model, ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles, - Map<String, RankingExpression> expressions) { - expression = replaceConstantsByFunctions(expression, constantsReplacedByFunctions); - reduceBatchDimensions(expression, model, profile, queryProfiles); + Map<String, ExpressionFunction> expressions) { + expression = expression.withBody(replaceConstantsByFunctions(expression.getBody(), constantsReplacedByFunctions)); + reduceBatchDimensions(expression.getBody(), model, profile, queryProfiles); store.writeExpression(expressionName, expression); expressions.put(expressionName, expression); } - private static Map<String, RankingExpression> convertStored(ModelStore store, RankProfile profile) { + private static Map<String, ExpressionFunction> convertStored(ModelStore store, RankProfile profile) { for (Pair<String, Tensor> constant : store.readSmallConstants()) profile.addConstant(constant.getFirst(), asValue(constant.getSecond())); @@ -290,15 +291,15 @@ public class ConvertedModel { } /** - * Verify that the functions referred in the given expression exists in the given rank profile, - * and return tensors of the types specified in requiredFunctions. + * Verify that the inputs declared in the given expression exists in the given rank profile as functions, + * and return tensors of the correct types. */ - private static void verifyRequiredFunctions(RankingExpression expression, ImportedModel model, - RankProfile profile, QueryProfileRegistry queryProfiles) { + private static void verifyInputs(RankingExpression expression, ImportedModel model, + RankProfile profile, QueryProfileRegistry queryProfiles) { Set<String> functionNames = new HashSet<>(); addFunctionNamesIn(expression.getRoot(), functionNames, model); for (String functionName : functionNames) { - TensorType requiredType = model.requiredFunctions().get(functionName); + TensorType requiredType = model.inputs().get(functionName); if (requiredType == null) continue; // Not a required function RankProfile.RankingExpressionFunction rankingExpressionFunction = profile.getFunctions().get(functionName); @@ -375,7 +376,7 @@ public class ConvertedModel { List<ExpressionNode> children = ((TensorFunctionNode)node).children(); if (children.size() == 1 && children.get(0) instanceof ReferenceNode) { ReferenceNode referenceNode = (ReferenceNode) children.get(0); - if (model.requiredFunctions().containsKey(referenceNode.getName())) { + if (model.inputs().containsKey(referenceNode.getName())) { return reduceBatchDimensionExpression(tensorFunction, typeContext); } } @@ -383,7 +384,7 @@ public class ConvertedModel { } if (node instanceof ReferenceNode) { ReferenceNode referenceNode = (ReferenceNode) node; - if (model.requiredFunctions().containsKey(referenceNode.getName())) { + if (model.inputs().containsKey(referenceNode.getName())) { return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext); } } @@ -451,7 +452,8 @@ public class ConvertedModel { Set<String> constantsReplacedByFunctions) { if (constantsReplacedByFunctions.isEmpty()) return expression; return new RankingExpression(expression.getName(), - replaceConstantsByFunctions(expression.getRoot(), constantsReplacedByFunctions)); + replaceConstantsByFunctions(expression.getRoot(), + constantsReplacedByFunctions)); } private static ExpressionNode replaceConstantsByFunctions(ExpressionNode node, Set<String> constantsReplacedByFunctions) { @@ -524,19 +526,21 @@ public class ConvertedModel { * @param name the name of this ranking expression - may have 1-3 parts separated by dot where the first part * is always the model name */ - void writeExpression(String name, RankingExpression expression) { - application.getFile(modelFiles.expressionPath(name)) - .writeFile(new StringReader(expression.getRoot().toString())); + void writeExpression(String name, ExpressionFunction expression) { + StringBuilder b = new StringBuilder(expression.getBody().getRoot().toString()); + for (Map.Entry<String, TensorType> input : expression.argumentTypes().entrySet()) + b.append('\n').append(input.getKey()).append('\t').append(input.getValue()); + application.getFile(modelFiles.expressionPath(name)).writeFile(new StringReader(b.toString())); } - Map<String, RankingExpression> readExpressions() { - Map<String, RankingExpression> expressions = new HashMap<>(); + Map<String, ExpressionFunction> readExpressions() { + Map<String, ExpressionFunction> expressions = new HashMap<>(); ApplicationFile expressionPath = application.getFile(modelFiles.expressionsPath()); if ( ! expressionPath.exists() || ! expressionPath.isDirectory()) return Collections.emptyMap(); for (ApplicationFile expressionFile : expressionPath.listFiles()) { - try (Reader reader = new BufferedReader(expressionFile.createReader())){ + try (BufferedReader reader = new BufferedReader(expressionFile.createReader())){ String name = expressionFile.getPath().getName(); - expressions.put(name, new RankingExpression(name, reader)); + expressions.put(name, readExpression(name, reader)); } catch (IOException e) { throw new UncheckedIOException("Failed reading " + expressionFile.getPath(), e); @@ -548,8 +552,22 @@ public class ConvertedModel { return expressions; } + private ExpressionFunction readExpression(String name, BufferedReader reader) + throws IOException, ParseException { + // First line is expression + RankingExpression expression = new RankingExpression(name, reader.readLine()); + // Next lines are inputs on the format name\ttensorTypeSpec + Map<String, TensorType> inputs = new LinkedHashMap<>(); + String line; + while (null != (line = reader.readLine())) { + String[] parts = line.split("\t"); + inputs.put(parts[0], TensorType.fromSpec(parts[1])); + } + return new ExpressionFunction(name, new ArrayList<>(inputs.keySet()), expression, inputs, Optional.empty()); + } + /** Adds this function expression to the application package so it can be read later. */ - void writeFunction(String name, RankingExpression expression) { + public void writeFunction(String name, RankingExpression expression) { application.getFile(modelFiles.functionsPath()).appendFile(name + "\t" + expression.getRoot().toString() + "\n"); } @@ -561,20 +579,20 @@ public class ConvertedModel { if ( ! file.exists()) return Collections.emptyList(); List<Pair<String, RankingExpression>> functions = new ArrayList<>(); - BufferedReader reader = new BufferedReader(file.createReader()); - String line; - while (null != (line = reader.readLine())) { - String[] parts = line.split("\t"); - String name = parts[0]; - try { - RankingExpression expression = new RankingExpression(parts[0], parts[1]); - functions.add(new Pair<>(name, expression)); - } - catch (ParseException e) { - throw new IllegalStateException("Could not parse " + name, e); + try (BufferedReader reader = new BufferedReader(file.createReader())) { + String line; + while (null != (line = reader.readLine())) { + String[] parts = line.split("\t"); + String name = parts[0]; + try { + RankingExpression expression = new RankingExpression(parts[0], parts[1]); + functions.add(new Pair<>(name, expression)); + } catch (ParseException e) { + throw new IllegalStateException("Could not parse " + name, e); + } } + return functions; } - return functions; } catch (IOException e) { throw new UncheckedIOException(e); diff --git a/config-model/src/test/derived/gemini2/gemini.sd b/config-model/src/test/derived/gemini2/gemini.sd index 01e20c1b30a..8a570e58fa8 100644 --- a/config-model/src/test/derived/gemini2/gemini.sd +++ b/config-model/src/test/derived/gemini2/gemini.sd @@ -2,6 +2,12 @@ search gemini { document gemini { + field right type string { + indexing: attribute + } + field wrong type string { + indexing: attribute + } } rank-profile test { diff --git a/config-model/src/test/derived/gemini2/rank-profiles.cfg b/config-model/src/test/derived/gemini2/rank-profiles.cfg index aa4f963320d..2b73e923c88 100644 --- a/config-model/src/test/derived/gemini2/rank-profiles.cfg +++ b/config-model/src/test/derived/gemini2/rank-profiles.cfg @@ -21,9 +21,13 @@ rankprofile[].fef.property[].name "rankingExpression(wrapper1@2d437c13405e61d6). rankprofile[].fef.property[].value "rankingExpression(wrapper2@2d437c13405e61d6)" rankprofile[].fef.property[].name "rankingExpression(toplevel).rankingScript" rankprofile[].fef.property[].value "rankingExpression(wrapper1@2d437c13405e61d6)" +rankprofile[].fef.property[].name "rankingExpression(toplevel).type" +rankprofile[].fef.property[].value "tensor()" rankprofile[].fef.property[].name "rankingExpression(wrapper2@8fc8470e911f253f).rankingScript" rankprofile[].fef.property[].value "attribute(wrong)" rankprofile[].fef.property[].name "rankingExpression(wrapper1@8fc8470e911f253f).rankingScript" rankprofile[].fef.property[].value "rankingExpression(wrapper2@8fc8470e911f253f)" rankprofile[].fef.property[].name "rankingExpression(interfering).rankingScript" rankprofile[].fef.property[].value "rankingExpression(wrapper1@8fc8470e911f253f)" +rankprofile[].fef.property[].name "rankingExpression(interfering).type" +rankprofile[].fef.property[].value "tensor()" diff --git a/config-model/src/test/derived/rankexpression/rank-profiles.cfg b/config-model/src/test/derived/rankexpression/rank-profiles.cfg index 9629ad863d4..d109ca4f0ec 100644 --- a/config-model/src/test/derived/rankexpression/rank-profiles.cfg +++ b/config-model/src/test/derived/rankexpression/rank-profiles.cfg @@ -128,6 +128,8 @@ rankprofile[].fef.property[].name "rankingExpression(fourtimessum).rankingScript rankprofile[].fef.property[].value "4 * (var1 + var2)" rankprofile[].fef.property[].name "rankingExpression(myfeature).rankingScript" rankprofile[].fef.property[].value "70 * fieldMatch(title).completeness * pow(0 - fieldMatch(title).earliness,2) + 30 * pow(0 - fieldMatch(description).earliness,2)" +rankprofile[].fef.property[].name "rankingExpression(myfeature).type" +rankprofile[].fef.property[].value "tensor()" rankprofile[].fef.property[].name "rankingExpression(fourtimessum@5cf279212355b980.67f1e87166cfef86).rankingScript" rankprofile[].fef.property[].value "4 * (match + rankBoost)" rankprofile[].fef.property[].name "vespa.rank.firstphase" @@ -145,10 +147,16 @@ rankprofile[].fef.property[].name "rankingExpression(fourtimessum).rankingScript rankprofile[].fef.property[].value "4 * (var1 + var2)" rankprofile[].fef.property[].name "rankingExpression(myfeature).rankingScript" rankprofile[].fef.property[].value "70 * fieldMatch(title).completeness * pow(0 - fieldMatch(title).earliness,2) + 30 * pow(0 - fieldMatch(description).earliness,2)" +rankprofile[].fef.property[].name "rankingExpression(myfeature).type" +rankprofile[].fef.property[].value "tensor()" rankprofile[].fef.property[].name "rankingExpression(mysummaryfeature).rankingScript" rankprofile[].fef.property[].value "70 * fieldMatch(title).completeness" +rankprofile[].fef.property[].name "rankingExpression(mysummaryfeature).type" +rankprofile[].fef.property[].value "tensor()" rankprofile[].fef.property[].name "rankingExpression(mysummaryfeature2).rankingScript" rankprofile[].fef.property[].value "71 * fieldMatch(title).completeness" +rankprofile[].fef.property[].name "rankingExpression(mysummaryfeature2).type" +rankprofile[].fef.property[].value "tensor()" rankprofile[].fef.property[].name "rankingExpression(fourtimessum@2b1138e8965e7ff5.67f1e87166cfef86).rankingScript" rankprofile[].fef.property[].value "4 * (match + match)" rankprofile[].fef.property[].name "vespa.rank.firstphase" @@ -164,11 +172,15 @@ rankprofile[].fef.property[].value "rankingExpression(mysummaryfeature)" rankprofile[].name "macros3" rankprofile[].fef.property[].name "rankingExpression(onlyusedinsummaryfeature).rankingScript" rankprofile[].fef.property[].value "5" +rankprofile[].fef.property[].name "rankingExpression(onlyusedinsummaryfeature).type" +rankprofile[].fef.property[].value "tensor()" rankprofile[].fef.property[].name "vespa.summary.feature" rankprofile[].fef.property[].value "rankingExpression(matches(title,rankingExpression(onlyusedinsummaryfeature)))" rankprofile[].name "macros3-inherited" rankprofile[].fef.property[].name "rankingExpression(onlyusedinsummaryfeature).rankingScript" rankprofile[].fef.property[].value "5" +rankprofile[].fef.property[].name "rankingExpression(onlyusedinsummaryfeature).type" +rankprofile[].fef.property[].value "tensor()" rankprofile[].fef.property[].name "vespa.summary.feature" rankprofile[].fef.property[].value "rankingExpression(matches(title,rankingExpression(onlyusedinsummaryfeature)))" rankprofile[].name "macros-inherited" @@ -178,10 +190,16 @@ rankprofile[].fef.property[].name "rankingExpression(fourtimessum).rankingScript rankprofile[].fef.property[].value "4 * (var1 + var2)" rankprofile[].fef.property[].name "rankingExpression(myfeature).rankingScript" rankprofile[].fef.property[].value "70 * fieldMatch(title).completeness * pow(0 - fieldMatch(title).earliness,2) + 30 * pow(0 - fieldMatch(description).earliness,2)" +rankprofile[].fef.property[].name "rankingExpression(myfeature).type" +rankprofile[].fef.property[].value "tensor()" rankprofile[].fef.property[].name "rankingExpression(mysummaryfeature).rankingScript" rankprofile[].fef.property[].value "80 * fieldMatch(title).completeness" +rankprofile[].fef.property[].name "rankingExpression(mysummaryfeature).type" +rankprofile[].fef.property[].value "tensor()" rankprofile[].fef.property[].name "rankingExpression(mysummaryfeature2).rankingScript" rankprofile[].fef.property[].value "71 * fieldMatch(title).completeness" +rankprofile[].fef.property[].name "rankingExpression(mysummaryfeature2).type" +rankprofile[].fef.property[].value "tensor()" rankprofile[].fef.property[].name "rankingExpression(fourtimessum@2b1138e8965e7ff5.67f1e87166cfef86).rankingScript" rankprofile[].fef.property[].value "4 * (match + match)" rankprofile[].fef.property[].name "vespa.rank.firstphase" @@ -203,10 +221,16 @@ rankprofile[].fef.property[].name "rankingExpression(fourtimessum).rankingScript rankprofile[].fef.property[].value "4 * (var1 + var2)" rankprofile[].fef.property[].name "rankingExpression(myfeature).rankingScript" rankprofile[].fef.property[].value "70 * fieldMatch(title).completeness * pow(0 - fieldMatch(title).earliness,2) + 30 * pow(0 - fieldMatch(description).earliness,2)" +rankprofile[].fef.property[].name "rankingExpression(myfeature).type" +rankprofile[].fef.property[].value "tensor()" rankprofile[].fef.property[].name "rankingExpression(mysummaryfeature).rankingScript" rankprofile[].fef.property[].value "80 * fieldMatch(title).completeness" +rankprofile[].fef.property[].name "rankingExpression(mysummaryfeature).type" +rankprofile[].fef.property[].value "tensor()" rankprofile[].fef.property[].name "rankingExpression(mysummaryfeature2).rankingScript" rankprofile[].fef.property[].value "71 * fieldMatch(title).completeness" +rankprofile[].fef.property[].name "rankingExpression(mysummaryfeature2).type" +rankprofile[].fef.property[].value "tensor()" rankprofile[].fef.property[].name "rankingExpression(fourtimessum@2b1138e8965e7ff5.67f1e87166cfef86).rankingScript" rankprofile[].fef.property[].value "4 * (match + match)" rankprofile[].fef.property[].name "vespa.rank.firstphase" @@ -228,10 +252,16 @@ rankprofile[].fef.property[].name "rankingExpression(fourtimessum).rankingScript rankprofile[].fef.property[].value "4 * (var1 + var2)" rankprofile[].fef.property[].name "rankingExpression(myfeature).rankingScript" rankprofile[].fef.property[].value "700 * fieldMatch(title).completeness" +rankprofile[].fef.property[].name "rankingExpression(myfeature).type" +rankprofile[].fef.property[].value "tensor()" rankprofile[].fef.property[].name "rankingExpression(mysummaryfeature).rankingScript" rankprofile[].fef.property[].value "80 * fieldMatch(title).completeness" +rankprofile[].fef.property[].name "rankingExpression(mysummaryfeature).type" +rankprofile[].fef.property[].value "tensor()" rankprofile[].fef.property[].name "rankingExpression(mysummaryfeature2).rankingScript" rankprofile[].fef.property[].value "71 * fieldMatch(title).completeness" +rankprofile[].fef.property[].name "rankingExpression(mysummaryfeature2).type" +rankprofile[].fef.property[].value "tensor()" rankprofile[].fef.property[].name "vespa.rank.firstphase" rankprofile[].fef.property[].value "rankingExpression(firstphase)" rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript" @@ -249,8 +279,14 @@ rankprofile[].fef.property[].name "rankingExpression(m1).rankingScript" rankprofile[].fef.property[].value "700 * fieldMatch(title).completeness" rankprofile[].fef.property[].name "rankingExpression(m2).rankingScript" rankprofile[].fef.property[].value "rankingExpression(m1) * 67" +rankprofile[].fef.property[].name "rankingExpression(m2).type" +rankprofile[].fef.property[].value "tensor()" +rankprofile[].fef.property[].name "rankingExpression(m1).type" +rankprofile[].fef.property[].value "tensor()" rankprofile[].fef.property[].name "rankingExpression(m4).rankingScript" rankprofile[].fef.property[].value "703 * fieldMatch(fromfile).completeness" +rankprofile[].fef.property[].name "rankingExpression(m4).type" +rankprofile[].fef.property[].value "tensor()" rankprofile[].fef.property[].name "vespa.rank.secondphase" rankprofile[].fef.property[].value "rankingExpression(secondphase)" rankprofile[].fef.property[].name "rankingExpression(secondphase).rankingScript" @@ -260,10 +296,18 @@ rankprofile[].fef.property[].name "rankingExpression(m1).rankingScript" rankprofile[].fef.property[].value "700 * fieldMatch(title).completeness" rankprofile[].fef.property[].name "rankingExpression(m2).rankingScript" rankprofile[].fef.property[].value "rankingExpression(m1) * 67" +rankprofile[].fef.property[].name "rankingExpression(m2).type" +rankprofile[].fef.property[].value "tensor()" +rankprofile[].fef.property[].name "rankingExpression(m1).type" +rankprofile[].fef.property[].value "tensor()" rankprofile[].fef.property[].name "rankingExpression(m4).rankingScript" rankprofile[].fef.property[].value "701 * fieldMatch(title).completeness" +rankprofile[].fef.property[].name "rankingExpression(m4).type" +rankprofile[].fef.property[].value "tensor()" rankprofile[].fef.property[].name "rankingExpression(m3).rankingScript" rankprofile[].fef.property[].value "if (isNan(attribute(nrtgmp)) == 1, 0.0, rankingExpression(m2))" +rankprofile[].fef.property[].name "rankingExpression(m3).type" +rankprofile[].fef.property[].value "tensor()" rankprofile[].fef.property[].name "vespa.rank.secondphase" rankprofile[].fef.property[].value "rankingExpression(secondphase)" rankprofile[].fef.property[].name "rankingExpression(secondphase).rankingScript" @@ -273,8 +317,14 @@ rankprofile[].fef.property[].name "rankingExpression(m1).rankingScript" rankprofile[].fef.property[].value "700 * fieldMatch(title).completeness" rankprofile[].fef.property[].name "rankingExpression(m2).rankingScript" rankprofile[].fef.property[].value "rankingExpression(m1) * 67" +rankprofile[].fef.property[].name "rankingExpression(m2).type" +rankprofile[].fef.property[].value "tensor()" +rankprofile[].fef.property[].name "rankingExpression(m1).type" +rankprofile[].fef.property[].value "tensor()" rankprofile[].fef.property[].name "rankingExpression(m4).rankingScript" rankprofile[].fef.property[].value "703 * fieldMatch(fromfile).completeness" +rankprofile[].fef.property[].name "rankingExpression(m4).type" +rankprofile[].fef.property[].value "tensor()" rankprofile[].fef.property[].name "vespa.rank.secondphase" rankprofile[].fef.property[].value "rankingExpression(secondphase)" rankprofile[].fef.property[].name "rankingExpression(secondphase).rankingScript" @@ -284,12 +334,22 @@ rankprofile[].fef.property[].name "rankingExpression(m1).rankingScript" rankprofile[].fef.property[].value "700 * fieldMatch(title).completeness" rankprofile[].fef.property[].name "rankingExpression(m2).rankingScript" rankprofile[].fef.property[].value "rankingExpression(m1) * 67" +rankprofile[].fef.property[].name "rankingExpression(m2).type" +rankprofile[].fef.property[].value "tensor()" +rankprofile[].fef.property[].name "rankingExpression(m1).type" +rankprofile[].fef.property[].value "tensor()" rankprofile[].fef.property[].name "rankingExpression(m4).rankingScript" rankprofile[].fef.property[].value "701 * fieldMatch(title).completeness" +rankprofile[].fef.property[].name "rankingExpression(m4).type" +rankprofile[].fef.property[].value "tensor()" rankprofile[].fef.property[].name "rankingExpression(m3).rankingScript" rankprofile[].fef.property[].value "if (isNan(attribute(nrtgmp)) == 1, 0.0, rankingExpression(m2))" +rankprofile[].fef.property[].name "rankingExpression(m3).type" +rankprofile[].fef.property[].value "tensor()" rankprofile[].fef.property[].name "rankingExpression(m5).rankingScript" rankprofile[].fef.property[].value "if (isNan(attribute(glmpfw)) == 1, rankingExpression(m1), rankingExpression(m4))" +rankprofile[].fef.property[].name "rankingExpression(m5).type" +rankprofile[].fef.property[].value "tensor()" rankprofile[].fef.property[].name "vespa.rank.secondphase" rankprofile[].fef.property[].value "rankingExpression(secondphase)" rankprofile[].fef.property[].name "rankingExpression(secondphase).rankingScript" diff --git a/config-model/src/test/derived/tensor/rank-profiles.cfg b/config-model/src/test/derived/tensor/rank-profiles.cfg index cb496c06367..471343da63c 100644 --- a/config-model/src/test/derived/tensor/rank-profiles.cfg +++ b/config-model/src/test/derived/tensor/rank-profiles.cfg @@ -43,10 +43,14 @@ rankprofile[].fef.property[].value "tensor(x{})" rankprofile[].fef.property[].name "vespa.type.attribute.f4" rankprofile[].fef.property[].value "tensor(x[10],y[20])" rankprofile[].name "profile3" +rankprofile[].fef.property[].name "rankingExpression(joinedtensors).rankingScript" +rankprofile[].fef.property[].value "tensor(i[10])(i) * attribute(f4)" +rankprofile[].fef.property[].name "rankingExpression(joinedtensors).type" +rankprofile[].fef.property[].value "tensor(i[10],x[10],y[20])" rankprofile[].fef.property[].name "vespa.rank.firstphase" rankprofile[].fef.property[].value "rankingExpression(firstphase)" rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript" -rankprofile[].fef.property[].value "reduce(tensor(i[10])(i) * attribute(f4), sum)" +rankprofile[].fef.property[].value "reduce(rankingExpression(joinedtensors), sum)" rankprofile[].fef.property[].name "vespa.type.attribute.f2" rankprofile[].fef.property[].value "tensor(x[2],y[])" rankprofile[].fef.property[].name "vespa.type.attribute.f3" diff --git a/config-model/src/test/derived/tensor/tensor.sd b/config-model/src/test/derived/tensor/tensor.sd index 5792a7997f8..54463448e28 100644 --- a/config-model/src/test/derived/tensor/tensor.sd +++ b/config-model/src/test/derived/tensor/tensor.sd @@ -36,7 +36,11 @@ search tensor { rank-profile profile3 { first-phase { - expression: sum(tensor(i[10])(i) * attribute(f4)) + expression: sum(joinedtensors) + } + + function joinedtensors() { + expression: tensor(i[10])(i) * attribute(f4) } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java index 150469cc928..2c1f4c8ecb6 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java @@ -82,7 +82,7 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase new ImportedModels(), new AttributeFields(s)).configProperties(); assertEquals("(rankingExpression(foo).rankingScript,14.0)", rankProperties.get(0).toString()); - assertEquals("(rankingExpression(firstphase).rankingScript,16.6)", rankProperties.get(2).toString()); + assertEquals("(rankingExpression(firstphase).rankingScript,16.6)", rankProperties.get(3).toString()); } @Test diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionLoopDetectionTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionLoopDetectionTestCase.java index 17bebcba70e..0ff8a5cc7ca 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionLoopDetectionTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionLoopDetectionTestCase.java @@ -40,7 +40,7 @@ public class RankingExpressionLoopDetectionTestCase { fail("Excepted exception"); } catch (IllegalArgumentException e) { - assertEquals("In search definition 'test', rank profile 'test': The first-phase expression is invalid: Invocation loop: foo -> foo", + assertEquals("In search definition 'test', rank profile 'test': The function 'foo' is invalid: Invocation loop: foo -> foo", Exceptions.toMessageString(e)); } } @@ -75,7 +75,7 @@ public class RankingExpressionLoopDetectionTestCase { fail("Excepted exception"); } catch (IllegalArgumentException e) { - assertEquals("In search definition 'test', rank profile 'test': The first-phase expression is invalid: Invocation loop: foo -> arg(5) -> foo", + assertEquals("In search definition 'test', rank profile 'test': The function 'foo' is invalid: Invocation loop: arg(5) -> foo -> arg(5)", Exceptions.toMessageString(e)); } } @@ -110,7 +110,7 @@ public class RankingExpressionLoopDetectionTestCase { fail("Excepted exception"); } catch (IllegalArgumentException e) { - assertEquals("In search definition 'test', rank profile 'test': The first-phase expression is invalid: Invocation loop: foo -> arg(foo) -> foo", + assertEquals("In search definition 'test', rank profile 'test': The function 'foo' is invalid: Invocation loop: arg(foo) -> foo -> arg(foo)", Exceptions.toMessageString(e)); } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java index e15d4075b19..4f99922a422 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java @@ -211,12 +211,16 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase censorBindingHash(testRankProperties.get(1).toString())); assertEquals("(rankingExpression(hidden_layer).rankingScript,rankingExpression(relu@))", censorBindingHash(testRankProperties.get(2).toString())); + assertEquals("(rankingExpression(hidden_layer).type,tensor(x[]))", + censorBindingHash(testRankProperties.get(3).toString())); assertEquals("(rankingExpression(final_layer).rankingScript,sigmoid(reduce(rankingExpression(hidden_layer) * constant(W_final), sum, hidden) + constant(b_final)))", - testRankProperties.get(3).toString()); - assertEquals("(vespa.rank.secondphase,rankingExpression(secondphase))", testRankProperties.get(4).toString()); - assertEquals("(rankingExpression(secondphase).rankingScript,reduce(rankingExpression(final_layer), sum))", + assertEquals("(rankingExpression(final_layer).type,tensor(x[]))", testRankProperties.get(5).toString()); + assertEquals("(vespa.rank.secondphase,rankingExpression(secondphase))", + testRankProperties.get(6).toString()); + assertEquals("(rankingExpression(secondphase).rankingScript,reduce(rankingExpression(final_layer), sum))", + testRankProperties.get(7).toString()); } private QueryProfileRegistry queryProfileWith(String field, String type) { diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidatorTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolverTestCase.java index 0d8cbbf2e6a..1b917b6f3a3 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidatorTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolverTestCase.java @@ -19,7 +19,7 @@ import static org.junit.Assert.fail; /** * @author bratseth */ -public class RankingExpressionTypeValidatorTestCase { +public class RankingExpressionTypeResolverTestCase { @Test public void tensorFirstPhaseMustProduceDouble() throws Exception { diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java index fd048737b43..7d62bc5089d 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java @@ -41,7 +41,7 @@ public class RankingExpressionsTestCase extends SearchDefinitionTestCase { new QueryProfileRegistry(), new ImportedModels(), new AttributeFields(search)).configProperties(); - assertEquals(6, rankProperties.size()); + assertEquals(7, rankProperties.size()); assertEquals("rankingExpression(titlematch$).rankingScript", rankProperties.get(0).getFirst()); assertEquals("var1 * var2 + 890", rankProperties.get(0).getSecond()); @@ -49,14 +49,17 @@ public class RankingExpressionsTestCase extends SearchDefinitionTestCase { assertEquals("rankingExpression(artistmatch).rankingScript", rankProperties.get(1).getFirst()); assertEquals("78 + closeness(distance)", rankProperties.get(1).getSecond()); - assertEquals("rankingExpression(firstphase).rankingScript", rankProperties.get(5).getFirst()); - assertEquals("0.8 + 0.2 * rankingExpression(titlematch$@126063073eb2deb.ab95cd69909927c) + 0.8 * rankingExpression(titlematch$@c7e4c2d0e6d9f2a1.1d4ed08e56cce2e6) * closeness(distance)", rankProperties.get(5).getSecond()); + assertEquals("rankingExpression(firstphase).rankingScript", rankProperties.get(6).getFirst()); + assertEquals("0.8 + 0.2 * rankingExpression(titlematch$@126063073eb2deb.ab95cd69909927c) + 0.8 * rankingExpression(titlematch$@c7e4c2d0e6d9f2a1.1d4ed08e56cce2e6) * closeness(distance)", rankProperties.get(6).getSecond()); - assertEquals("rankingExpression(titlematch$@c7e4c2d0e6d9f2a1.1d4ed08e56cce2e6).rankingScript", rankProperties.get(3).getFirst()); - assertEquals("7 * 8 + 890", rankProperties.get(3).getSecond()); + assertEquals("rankingExpression(titlematch$@c7e4c2d0e6d9f2a1.1d4ed08e56cce2e6).rankingScript", rankProperties.get(4).getFirst()); + assertEquals("7 * 8 + 890", rankProperties.get(4).getSecond()); - assertEquals("rankingExpression(titlematch$@126063073eb2deb.ab95cd69909927c).rankingScript", rankProperties.get(2).getFirst()); - assertEquals("4 * 5 + 890", rankProperties.get(2).getSecond()); + assertEquals("rankingExpression(artistmatch).type", rankProperties.get(2).getFirst()); + assertEquals("tensor()", rankProperties.get(2).getSecond()); + + assertEquals("rankingExpression(titlematch$@126063073eb2deb.ab95cd69909927c).rankingScript", rankProperties.get(3).getFirst()); + assertEquals("4 * 5 + 890", rankProperties.get(3).getSecond()); } @Test(expected = IllegalArgumentException.class) diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java index 8e721dbe503..6e3a227e2a9 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java @@ -58,8 +58,8 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase { "max(attribute(tensor_field_1),x)"); assertTransformedExpression("1+reduce(attribute(tensor_field_1),max,x)", "1 + max(attribute(tensor_field_1),x)"); - assertTransformedExpression("if(attribute(double_field),1+reduce(attribute(tensor_field_1),max,x),0)", - "if(attribute(double_field),1 + max(attribute(tensor_field_1),x),0)"); + assertTransformedExpression("if(attribute(double_field),1+reduce(attribute(tensor_field_1),max,x),attribute(tensor_field_1))", + "if(attribute(double_field),1 + max(attribute(tensor_field_1),x),attribute(tensor_field_1))"); assertTransformedExpression("reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),max,x)", "max(max(attribute(tensor_field_1),attribute(tensor_field_2)),x)"); assertTransformedExpression("reduce(if(attribute(double_field),attribute(tensor_field_2),attribute(tensor_field_2)),max,x)", diff --git a/container-search/src/main/java/com/yahoo/search/searchchain/AsyncExecution.java b/container-search/src/main/java/com/yahoo/search/searchchain/AsyncExecution.java index d74f29a3b77..4dbec304bde 100644 --- a/container-search/src/main/java/com/yahoo/search/searchchain/AsyncExecution.java +++ b/container-search/src/main/java/com/yahoo/search/searchchain/AsyncExecution.java @@ -147,7 +147,7 @@ public class AsyncExecution { } private static <T> Future<T> getFuture(Callable<T> callable) { - final FutureTask<T> future = new FutureTask<>(callable); + FutureTask<T> future = new FutureTask<>(callable); getExecutor().execute(future); return future; } diff --git a/processing/src/main/java/com/yahoo/processing/execution/AsyncExecution.java b/processing/src/main/java/com/yahoo/processing/execution/AsyncExecution.java index eac96e9b408..2c40165f8e5 100644 --- a/processing/src/main/java/com/yahoo/processing/execution/AsyncExecution.java +++ b/processing/src/main/java/com/yahoo/processing/execution/AsyncExecution.java @@ -59,14 +59,14 @@ public class AsyncExecution { /** * Create an async execution of a single processor */ - public AsyncExecution(final Processor processor, Execution parent) { + public AsyncExecution(Processor processor, Execution parent) { this(new Execution(processor, parent)); } /** * Create an async execution of a chain */ - public AsyncExecution(final Chain<? extends Processor> chain, Execution parent) { + public AsyncExecution(Chain<? extends Processor> chain, Execution parent) { this(new Execution(chain, parent)); } @@ -81,7 +81,7 @@ public class AsyncExecution { * * @param execution the execution from which the state of this is created */ - public AsyncExecution(final Execution execution) { + public AsyncExecution(Execution execution) { this.execution = new Execution(execution); } @@ -89,7 +89,7 @@ public class AsyncExecution { * Performs an async processing. Note that the given request cannot be simultaneously * used in multiple such processings - a clone must be created for each. */ - public FutureResponse process(final Request request) { + public FutureResponse process(Request request) { return getFutureResponse(new Callable<Response>() { @Override public Response call() { @@ -99,13 +99,13 @@ public class AsyncExecution { } private static <T> Future<T> getFuture(final Callable<T> callable) { - final FutureTask<T> future = new FutureTask<>(callable); + FutureTask<T> future = new FutureTask<>(callable); executorMain.execute(future); return future; } - private FutureResponse getFutureResponse(final Callable<Response> callable, final Request request) { - final FutureResponse future = new FutureResponse(callable, execution, request); + private FutureResponse getFutureResponse(Callable<Response> callable, Request request) { + FutureResponse future = new FutureResponse(callable, execution, request); executorMain.execute(future.delegate()); return future; } @@ -118,15 +118,15 @@ public class AsyncExecution { * @return the list of responses in the same order as returned from the task collection */ // Note that this may also be achieved using guava Futures. Not sure if this should be deprecated because of it. - public static List<Response> waitForAll(final Collection<FutureResponse> tasks, final long timeout) { + public static List<Response> waitForAll(Collection<FutureResponse> tasks, long timeout) { // Copy the list in case it is modified while we are waiting - final List<FutureResponse> workingTasks = new ArrayList<>(tasks); + List<FutureResponse> workingTasks = new ArrayList<>(tasks); @SuppressWarnings({"rawtypes", "unchecked"}) - final Future task = getFuture(new Callable() { + Future task = getFuture(new Callable() { @Override public List<Future> call() { - for (final FutureResponse task : workingTasks) { + for (FutureResponse task : workingTasks) { task.get(); } return null; @@ -135,12 +135,12 @@ public class AsyncExecution { try { task.get(timeout, TimeUnit.MILLISECONDS); - } catch (final TimeoutException | InterruptedException | ExecutionException e) { + } catch (TimeoutException | InterruptedException | ExecutionException e) { // Handle timeouts below } - final List<Response> responses = new ArrayList<>(tasks.size()); - for (final FutureResponse future : workingTasks) + List<Response> responses = new ArrayList<>(tasks.size()); + for (FutureResponse future : workingTasks) responses.add(getTaskResponse(future)); return responses; } @@ -153,5 +153,4 @@ public class AsyncExecution { } } - } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java index da34ab8822d..f6502a9801d 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java @@ -2,8 +2,11 @@ package com.yahoo.searchlib.rankingexpression; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.yahoo.log.event.Collection; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.SerializationContext; +import com.yahoo.tensor.TensorType; import com.yahoo.text.Utf8; import java.security.MessageDigest; @@ -13,9 +16,14 @@ import java.util.Deque; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.Optional; /** - * A function defined by a ranking expression + * A function defined by a ranking expression, optionally containing type information + * for inputs and outputs. + * + * Immutable, but note that ranking expressions are *not* immutable. * * @author Simon Thoresen Hult * @author bratseth @@ -24,8 +32,13 @@ public class ExpressionFunction { private final String name; private final ImmutableList<String> arguments; + + /** Types of the inputs, if known. The keys here is any subset (including empty and identity) of the argument list */ + private final ImmutableMap<String, TensorType> argumentTypes; private final RankingExpression body; + private final Optional<TensorType> returnType; + /** * Constructs a new function with no arguments * @@ -44,9 +57,18 @@ public class ExpressionFunction { * @param body the ranking expression that defines this function */ public ExpressionFunction(String name, List<String> arguments, RankingExpression body) { - this.name = name; + this(name, arguments, body, ImmutableMap.of(), Optional.empty()); + } + + public ExpressionFunction(String name, List<String> arguments, RankingExpression body, + Map<String, TensorType> argumentTypes, Optional<TensorType> returnType) { + this.name = Objects.requireNonNull(name, "name cannot be null"); this.arguments = arguments==null ? ImmutableList.of() : ImmutableList.copyOf(arguments); - this.body = body; + this.body = Objects.requireNonNull(body, "body cannot be null"); + if ( ! this.arguments.containsAll(argumentTypes.keySet())) + throw new IllegalArgumentException("Argument type keys must be a subset of the argument keys"); + this.argumentTypes = ImmutableMap.copyOf(argumentTypes); + this.returnType = Objects.requireNonNull(returnType, "returnType cannot be null"); } public String getName() { return name; } @@ -56,9 +78,27 @@ public class ExpressionFunction { public RankingExpression getBody() { return body; } + /** Returns the types of the arguments of this, if specified. The keys of this may be any subset of the arguments */ + public Map<String, TensorType> argumentTypes() { return argumentTypes; } + + /** Returns the return type of this, or empty if not specified */ + public Optional<TensorType> returnType() { return returnType; } + + public ExpressionFunction withName(String name) { + return new ExpressionFunction(name, arguments, body, argumentTypes, returnType); + } + /** Returns a copy of this with the body changed to the given value */ public ExpressionFunction withBody(RankingExpression body) { - return new ExpressionFunction(name, arguments, body); + return new ExpressionFunction(name, arguments, body, argumentTypes, returnType); + } + + public ExpressionFunction withReturnType(TensorType returnType) { + return new ExpressionFunction(name, arguments, body, argumentTypes, Optional.of(returnType)); + } + + public ExpressionFunction withArgumentTypes(Map<String, TensorType> argumentTypes) { + return new ExpressionFunction(name, arguments, body, argumentTypes, returnType); } /** diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java index 282a4c5e0a9..9ff391a5cfe 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java @@ -1,15 +1,22 @@ package com.yahoo.searchlib.rankingexpression.integration.ml; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.yahoo.collections.Pair; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.Optional; import java.util.regex.Pattern; /** @@ -26,12 +33,11 @@ public class ImportedModel { private final String source; private final Map<String, Signature> signatures = new HashMap<>(); - private final Map<String, TensorType> arguments = new HashMap<>(); + private final Map<String, TensorType> inputs = new HashMap<>(); private final Map<String, Tensor> smallConstants = new HashMap<>(); private final Map<String, Tensor> largeConstants = new HashMap<>(); private final Map<String, RankingExpression> expressions = new HashMap<>(); private final Map<String, RankingExpression> functions = new HashMap<>(); - private final Map<String, TensorType> requiredFunctions = new HashMap<>(); /** * Creates a new imported model. @@ -49,11 +55,11 @@ public class ImportedModel { /** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */ public String name() { return name; } - /** Returns the source path (directiry or file) of this model */ + /** Returns the source path (directory or file) of this model */ public String source() { return source; } - /** Returns an immutable map of the arguments ("Placeholders") of this */ - public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); } + /** Returns an immutable map of the inputs of this */ + public Map<String, TensorType> inputs() { return Collections.unmodifiableMap(inputs); } /** * Returns an immutable map of the small constants of this. @@ -71,7 +77,7 @@ public class ImportedModel { /** * Returns an immutable map of the expressions of this - corresponding to graph nodes - * which are not Inputs/Placeholders or Variables (which instead become respectively arguments and constants). + * which are not Inputs/Placeholders or Variables (which instead become respectively inputs and constants). * Note that only nodes recursively referenced by a placeholder/input are added. */ public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); } @@ -82,9 +88,6 @@ public class ImportedModel { */ public Map<String, RankingExpression> functions() { return Collections.unmodifiableMap(functions); } - /** Returns an immutable map of the functions that must be provided by the environment running this model */ - public Map<String, TensorType> requiredFunctions() { return Collections.unmodifiableMap(requiredFunctions); } - /** Returns an immutable map of the signatures of this */ public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); } @@ -96,12 +99,11 @@ public class ImportedModel { /** Convenience method for returning a default signature */ Signature defaultSignature() { return signature(defaultSignatureName); } - void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); } + void input(String name, TensorType argumentType) { inputs.put(name, argumentType); } void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); } void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); } void expression(String name, RankingExpression expression) { expressions.put(name, expression); } void function(String name, RankingExpression expression) { functions.put(name, expression); } - void requiredFunction(String name, TensorType type) { requiredFunctions.put(name, type); } /** * Returns all the output expressions of this indexed by name. The names consist of one or two parts @@ -109,24 +111,39 @@ public class ImportedModel { * if signatures are used, or the expression name if signatures are not used and there are multiple * expressions, and the second is the output name if signature names are used. */ - public List<Pair<String, RankingExpression>> outputExpressions() { - List<Pair<String, RankingExpression>> expressions = new ArrayList<>(); + public List<Pair<String, ExpressionFunction>> outputExpressions() { + List<Pair<String, ExpressionFunction>> expressions = new ArrayList<>(); for (Map.Entry<String, Signature> signatureEntry : signatures().entrySet()) { for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet()) expressions.add(new Pair<>(signatureEntry.getKey() + "." + outputEntry.getKey(), - expressions().get(outputEntry.getValue()))); + signatureEntry.getValue().outputExpression(outputEntry.getKey()) + .withName(signatureEntry.getKey() + "." + outputEntry.getKey()))); if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs expressions.add(new Pair<>(signatureEntry.getKey(), - expressions().get(signatureEntry.getKey()))); + new ExpressionFunction(signatureEntry.getKey(), + new ArrayList<>(signatureEntry.getValue().inputs().keySet()), + expressions().get(signatureEntry.getKey()), + signatureEntry.getValue().inputMap(), + Optional.empty()))); } if (signatures().isEmpty()) { // fallback for models without signatures if (expressions().size() == 1) { Map.Entry<String, RankingExpression> singleEntry = this.expressions.entrySet().iterator().next(); - expressions.add(new Pair<>(singleEntry.getKey(), singleEntry.getValue())); + expressions.add(new Pair<>(singleEntry.getKey(), + new ExpressionFunction(singleEntry.getKey(), + new ArrayList<>(inputs.keySet()), + singleEntry.getValue(), + inputs, + Optional.empty()))); } else { for (Map.Entry<String, RankingExpression> expressionEntry : expressions().entrySet()) { - expressions.add(new Pair<>(expressionEntry.getKey(), expressionEntry.getValue())); + expressions.add(new Pair<>(expressionEntry.getKey(), + new ExpressionFunction(expressionEntry.getKey(), + new ArrayList<>(inputs.keySet()), + expressionEntry.getValue(), + inputs, + Optional.empty()))); } } } @@ -134,7 +151,7 @@ public class ImportedModel { } /** - * A signature is a set of named inputs and outputs, where the inputs maps to argument + * A signature is a set of named inputs and outputs, where the inputs maps to input * ("placeholder") names+types, and outputs maps to expressions nodes. * Note that TensorFlow supports multiple signatures in their format, but ONNX has no explicit * concept of signatures. For now, we handle ONNX models as having a single signature. @@ -142,8 +159,8 @@ public class ImportedModel { public class Signature { private final String name; - private final Map<String, String> inputs = new HashMap<>(); - private final Map<String, String> outputs = new HashMap<>(); + private final Map<String, String> inputs = new LinkedHashMap<>(); + private final Map<String, String> outputs = new LinkedHashMap<>(); private final Map<String, String> skippedOutputs = new HashMap<>(); private final List<String> importWarnings = new ArrayList<>(); @@ -158,12 +175,20 @@ public class ImportedModel { /** * Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name - * to argument (Placeholder) name in the owner of this + * in this signature to input name in the owning model */ public Map<String, String> inputs() { return Collections.unmodifiableMap(inputs); } - /** Returns the type of the argument this input references */ - public TensorType inputArgument(String inputName) { return owner().arguments().get(inputs.get(inputName)); } + /** Returns the name and type of all inputs in this signature as an immutable map */ + public Map<String, TensorType> inputMap() { + ImmutableMap.Builder<String, TensorType> inputs = new ImmutableMap.Builder<>(); + for (Map.Entry<String, String> inputEntry : inputs().entrySet()) + inputs.put(inputEntry.getKey(), owner().inputs().get(inputEntry.getValue())); + return inputs.build(); + } + + /** Returns the type of the input this input references */ + public TensorType inputArgument(String inputName) { return owner().inputs().get(inputs.get(inputName)); } /** Returns an immutable list of the expression names of this */ public Map<String, String> outputs() { return Collections.unmodifiableMap(outputs); } @@ -180,7 +205,13 @@ public class ImportedModel { public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); } /** Returns the expression this output references */ - public RankingExpression outputExpression(String outputName) { return owner().expressions().get(outputs.get(outputName)); } + public ExpressionFunction outputExpression(String outputName) { + return new ExpressionFunction(outputName, + new ArrayList<>(inputs.keySet()), + owner().expressions().get(outputs.get(outputName)), + inputMap(), + Optional.empty()); + } @Override public String toString() { return "signature '" + name + "'"; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java index d25502fd149..b7138ad87e3 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java @@ -187,8 +187,7 @@ public abstract class ModelImporter { if (operation.isInput()) { // All inputs must have dimensions with standard naming convention: d0, d1, ... OrderedTensorType standardNamingConvention = OrderedTensorType.standardType(operation.type().get()); - model.argument(operation.vespaName(), standardNamingConvention.type()); - model.requiredFunction(operation.vespaName(), standardNamingConvention.type()); + model.input(operation.vespaName(), standardNamingConvention.type()); } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java index 917b0d6a389..e6bb5f40b3f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java @@ -2,7 +2,6 @@ package com.yahoo.searchlib.rankingexpression.integration.ml; -import com.yahoo.io.IOUtils; import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph; import com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx.GraphImporter; import onnx.Onnx; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java index 796c13a8669..94d663b4954 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java @@ -3,6 +3,8 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.google.common.collect.ImmutableMap; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.tensor.TensorType; import java.util.Collection; import java.util.Collections; @@ -80,9 +82,14 @@ public class SerializationContext extends FunctionReferenceContext { serializedFunctions.put(name, expressionString); } - /** Returns the existing serialization of a function, or null if none */ - public String getFunctionSerialization(String name) { - return serializedFunctions.get(name); + /** Adds the serialization of the an argument type to a function */ + public void addArgumentTypeSerialization(String functionName, String argumentName, TensorType type) { + serializedFunctions.put("rankingExpression(" + functionName + ")." + argumentName + ".type", type.toString()); + } + + /** Adds the serialization of the return type of a function */ + public void addFunctionTypeSerialization(String functionName, TensorType type) { + serializedFunctions.put("rankingExpression(" + functionName + ").type", type.toString()); } @Override diff --git a/searchlib/src/test/java/com/yahoo/searchlib/aggregation/GroupingSerializationTest.java b/searchlib/src/test/java/com/yahoo/searchlib/aggregation/GroupingSerializationTest.java index 118eba2cd96..969bc318391 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/aggregation/GroupingSerializationTest.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/aggregation/GroupingSerializationTest.java @@ -15,6 +15,7 @@ import org.junit.Test; import java.io.*; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.nio.charset.StandardCharsets; import java.util.Arrays; import static org.junit.Assert.fail; @@ -40,7 +41,8 @@ public class GroupingSerializationTest { t.assertMatch(new FloatResultNode(7.3)); t.assertMatch(new StringResultNode("7.3")); t.assertMatch(new StringResultNode( - new String(new byte[]{(byte)0xe5, (byte)0xa6, (byte)0x82, (byte)0xe6, (byte)0x9e, (byte)0x9c}))); + new String(new byte[]{(byte)0xe5, (byte)0xa6, (byte)0x82, (byte)0xe6, (byte)0x9e, (byte)0x9c}, + StandardCharsets.UTF_8))); t.assertMatch(new RawResultNode(new byte[]{'7', '.', '4'})); t.assertMatch(new IntegerBucketResultNode()); t.assertMatch(new FloatBucketResultNode()); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java index bf9684082f4..593e7b54c10 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.integration.ml; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import org.junit.Test; @@ -20,10 +21,11 @@ public class BatchNormImportTestCase { assertEquals("Has skipped outputs", 0, model.get().signature("serving_default").skippedOutputs().size()); - RankingExpression output = signature.outputExpression("y"); + ExpressionFunction output = signature.outputExpression("y"); assertNotNull(output); - assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.getName()); - model.assertEqualResult("X", output.getName()); + assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.getBody().getName()); + model.assertEqualResult("X", output.getBody().getName()); + assertEquals("{x=tensor(d0[],d1[784])}", output.argumentTypes().toString()); } } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java index a8f7542f3a4..59712c0152f 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.integration.ml; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.tensor.TensorType; import org.junit.Test; @@ -19,22 +20,23 @@ public class DropoutImportTestCase { TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/dropout/saved"); // Check required functions - assertEquals(1, model.get().requiredFunctions().size()); - assertTrue(model.get().requiredFunctions().containsKey("X")); + assertEquals(1, model.get().inputs().size()); + assertTrue(model.get().inputs().containsKey("X")); assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), - model.get().requiredFunctions().get("X")); + model.get().inputs().get("X")); ImportedModel.Signature signature = model.get().signature("serving_default"); assertEquals("Has skipped outputs", 0, model.get().signature("serving_default").skippedOutputs().size()); - RankingExpression output = signature.outputExpression("y"); + ExpressionFunction output = signature.outputExpression("y"); assertNotNull(output); - assertEquals("outputs/Maximum", output.getName()); + assertEquals("outputs/Maximum", output.getBody().getName()); assertEquals("join(join(imported_ml_function_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_function_test_outputs_BiasAdd, f(a,b)(max(a,b)))", - output.getRoot().toString()); - model.assertEqualResult("X", output.getName()); + output.getBody().getRoot().toString()); + model.assertEqualResult("X", output.getBody().getName()); + assertEquals("{x=tensor(d0[],d1[784])}", output.argumentTypes().toString()); } } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java index add66eece1a..3d8d5d5a570 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.integration.ml; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import org.junit.Test; @@ -20,11 +21,10 @@ public class MnistImportTestCase { assertEquals("Has skipped outputs", 0, model.get().signature("serving_default").skippedOutputs().size()); - RankingExpression output = signature.outputExpression("y"); + ExpressionFunction output = signature.outputExpression("y"); assertNotNull(output); - assertEquals("dnn/outputs/add", output.getName()); - model.assertEqualResultSum("input", output.getName(), 0.00001); + assertEquals("dnn/outputs/add", output.getBody().getName()); + model.assertEqualResultSum("input", output.getBody().getName(), 0.00001); } - } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java index e20ac16a691..b6e83404ab1 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java @@ -1,5 +1,6 @@ package com.yahoo.searchlib.rankingexpression.integration.ml; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; @@ -27,27 +28,28 @@ public class OnnxMnistSoftmaxImportTestCase { Tensor constant0 = model.largeConstants().get("test_Variable"); assertNotNull(constant0); assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(), - constant0.type()); + constant0.type()); assertEquals(7840, constant0.size()); Tensor constant1 = model.largeConstants().get("test_Variable_1"); assertNotNull(constant1); - assertEquals(new TensorType.Builder().indexed("d1", 10).build(), - constant1.type()); + assertEquals(new TensorType.Builder().indexed("d1", 10).build(), constant1.type()); assertEquals(10, constant1.size()); - // Check required functions (inputs) - assertEquals(1, model.requiredFunctions().size()); - assertTrue(model.requiredFunctions().containsKey("Placeholder")); - assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), - model.requiredFunctions().get("Placeholder")); + // Check inputs + assertEquals(1, model.inputs().size()); + assertTrue(model.inputs().containsKey("Placeholder")); + assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), model.inputs().get("Placeholder")); - // Check outputs - RankingExpression output = model.defaultSignature().outputExpression("add"); + // Check signature + ExpressionFunction output = model.defaultSignature().outputExpression("add"); assertNotNull(output); - assertEquals("add", output.getName()); + assertEquals("add", output.getBody().getName()); assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))", - output.getRoot().toString()); + output.getBody().getRoot().toString()); + assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), + model.inputs().get(model.defaultSignature().inputs().get("Placeholder"))); + assertEquals("{Placeholder=tensor(d0[],d1[784])}", output.argumentTypes().toString()); } @Test diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java index ef28eb4678f..0a48ecfce21 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.integration.ml; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -38,10 +39,10 @@ public class TensorFlowMnistSoftmaxImportTestCase { assertEquals(0, model.get().functions().size()); // Check required functions - assertEquals(1, model.get().requiredFunctions().size()); - assertTrue(model.get().requiredFunctions().containsKey("Placeholder")); + assertEquals(1, model.get().inputs().size()); + assertTrue(model.get().inputs().containsKey("Placeholder")); assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), - model.get().requiredFunctions().get("Placeholder")); + model.get().inputs().get("Placeholder")); // Check signatures assertEquals(1, model.get().signatures().size()); @@ -56,11 +57,12 @@ public class TensorFlowMnistSoftmaxImportTestCase { // ... signature outputs assertEquals(1, signature.outputs().size()); - RankingExpression output = signature.outputExpression("y"); + ExpressionFunction output = signature.outputExpression("y"); assertNotNull(output); - assertEquals("add", output.getName()); + assertEquals("add", output.getBody().getName()); assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable_read), f(a,b)(a * b)), sum, d2), constant(test_Variable_1_read), f(a,b)(a + b))", - output.getRoot().toString()); + output.getBody().getRoot().toString()); + assertEquals("{x=tensor(d0[],d1[784])}", output.argumentTypes().toString()); // Test execution model.assertEqualResult("Placeholder", "MatMul"); |