diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo')
15 files changed, 251 insertions, 166 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index 16e494c2db1..78f61d7192d 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -39,6 +39,7 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -128,7 +129,7 @@ public class RankProfile implements Serializable, Cloneable { /** * Creates a global rank profile * - * @param name the name of the new profile + * @param name the name of the new profile * @param model the model owning this profile */ public RankProfile(String name, VespaModel model, RankProfileRegistry rankProfileRegistry) { @@ -231,8 +232,8 @@ public class RankProfile implements Serializable, Cloneable { /** * Returns the a rank setting of a field, or null if there is no such rank setting in this profile * - * @param field The field whose settings to return. - * @param type The type that the field is required to be. + * @param field the field whose settings to return. + * @param type the type that the field is required to be. * @return the rank setting found, or null. */ public RankSetting getDeclaredRankSetting(String field, RankSetting.Type type) { @@ -449,7 +450,7 @@ public class RankProfile implements Serializable, Cloneable { addRankProperty(new RankProperty(name, parameter)); } - public void addRankProperty(RankProperty rankProperty) { + private void addRankProperty(RankProperty rankProperty) { // Just the usual multimap semantics here List<RankProperty> properties = rankProperties.get(rankProperty.getName()); if (properties == null) { @@ -539,17 +540,16 @@ public class RankProfile implements Serializable, Cloneable { return rankingExpressionFunction; } - /** Returns an unmodifiable view of the functions in this */ + /** Returns an unmodifiable snapshot of the functions in this */ public Map<String, RankingExpressionFunction> getFunctions() { - if (functions.size() == 0 && getInherited() == null) return Collections.emptyMap(); - if (functions.size() == 0) return getInherited().getFunctions(); - if (getInherited() == null) return Collections.unmodifiableMap(functions); + if (functions.isEmpty() && getInherited() == null) return Collections.emptyMap(); + if (functions.isEmpty()) return getInherited().getFunctions(); + if (getInherited() == null) return Collections.unmodifiableMap(new LinkedHashMap<>(functions)); // Neither is null Map<String, RankingExpressionFunction> allFunctions = new LinkedHashMap<>(getInherited().getFunctions()); allFunctions.putAll(functions); return Collections.unmodifiableMap(allFunctions); - } public int getKeepRankCount() { @@ -664,10 +664,10 @@ public class RankProfile implements Serializable, Cloneable { // Function compiling first pass: compile inline functions without resolving other functions Map<String, RankingExpressionFunction> inlineFunctions = - compileFunctions(getInlineFunctions(), queryProfiles, importedModels, Collections.emptyMap(), expressionTransforms); + compileFunctions(this::getInlineFunctions, queryProfiles, importedModels, Collections.emptyMap(), expressionTransforms); // Function compiling second pass: compile all functions and insert previously compiled inline functions - functions = compileFunctions(getFunctions(), queryProfiles, importedModels, inlineFunctions, expressionTransforms); + functions = compileFunctions(this::getFunctions, queryProfiles, importedModels, inlineFunctions, expressionTransforms); firstPhaseRanking = compile(this.getFirstPhaseRanking(), queryProfiles, importedModels, getConstants(), inlineFunctions, expressionTransforms); secondPhaseRanking = compile(this.getSecondPhaseRanking(), queryProfiles, importedModels, getConstants(), inlineFunctions, expressionTransforms); @@ -686,20 +686,34 @@ public class RankProfile implements Serializable, Cloneable { .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); } - private Map<String, RankingExpressionFunction> compileFunctions(Map<String, RankingExpressionFunction> functions, + private Map<String, RankingExpressionFunction> compileFunctions(Supplier<Map<String, RankingExpressionFunction>> functions, QueryProfileRegistry queryProfiles, ImportedModels importedModels, Map<String, RankingExpressionFunction> inlineFunctions, ExpressionTransforms expressionTransforms) { Map<String, RankingExpressionFunction> compiledFunctions = new LinkedHashMap<>(); - for (Map.Entry<String, RankingExpressionFunction> entry : functions.entrySet()) { + Map.Entry<String, RankingExpressionFunction> entry; + // Compile all functions. Why iterate in such a complicated way? + // Because some functions (imported models adding generated macros) may add other functions during compiling. + // A straightforward iteration will either miss those functions, or may cause a ConcurrentModificationException + while (null != (entry = findUncompiledFunction(functions.get(), compiledFunctions.keySet()))) { RankingExpressionFunction rankingExpressionFunction = entry.getValue(); - RankingExpression compiled = compile(rankingExpressionFunction.function().getBody(), queryProfiles, importedModels, getConstants(), inlineFunctions, expressionTransforms); - compiledFunctions.put(entry.getKey(), rankingExpressionFunction.withBody(compiled)); + RankingExpression compiled = compile(rankingExpressionFunction.function().getBody(), queryProfiles, + importedModels, getConstants(), inlineFunctions, expressionTransforms); + compiledFunctions.put(entry.getKey(), rankingExpressionFunction.withExpression(compiled)); } return compiledFunctions; } + private Map.Entry<String, RankingExpressionFunction> findUncompiledFunction(Map<String, RankingExpressionFunction> functions, + Set<String> compiledFunctionNames) { + for (Map.Entry<String, RankingExpressionFunction> entry : functions.entrySet()) { + if ( ! compiledFunctionNames.contains(entry.getKey())) + return entry; + } + return null; + } + private RankingExpression compile(RankingExpression expression, QueryProfileRegistry queryProfiles, ImportedModels importedModels, @@ -898,7 +912,7 @@ public class RankProfile implements Serializable, Cloneable { /** A function in a rank profile */ public static class RankingExpressionFunction { - private final ExpressionFunction function; + private ExpressionFunction function; /** True if this should be inlined into calling expressions. Useful for very cheap functions. */ private final boolean inline; @@ -908,13 +922,17 @@ public class RankProfile implements Serializable, Cloneable { this.inline = inline; } + public void setReturnType(TensorType type) { + this.function = function.withReturnType(type); + } + public ExpressionFunction function() { return function; } public boolean inline() { return inline && function.arguments().isEmpty(); // only inline no-arg functions; } - public RankingExpressionFunction withBody(RankingExpression expression) { + public RankingExpressionFunction withExpression(RankingExpression expression) { return new RankingExpressionFunction(function.withBody(expression), inline); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/Search.java b/config-model/src/main/java/com/yahoo/searchdefinition/Search.java index f42d5de21e8..a988da9664e 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/Search.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/Search.java @@ -59,9 +59,6 @@ public class Search implements Serializable, ImmutableSearch { // Field sets private FieldSets fieldSets = new FieldSets(); - // Whether or not this object has been processed. - private boolean processed; - // The unique name of this search definition. private String name; @@ -585,17 +582,6 @@ public class Search implements Serializable, ImmutableSearch { return false; } - public void process() { - if (processed) { - throw new IllegalStateException("Search '" + getName() + "' already processed."); - } - processed = true; - } - - public boolean isProcessed() { - return processed; - } - /** * The field set settings for this search * diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java b/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java index 3c2ebc058ac..151ad02a3fa 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java @@ -188,10 +188,6 @@ public class SearchBuilder { throw new IllegalArgumentException("Search has no name."); } String rawName = rawSearch.getName(); - if (rawSearch.isProcessed()) { - throw new IllegalArgumentException("A search definition with a search section called '" + rawName + - "' has already been processed."); - } for (Search search : searchList) { if (rawName.equals(search.getName())) { throw new IllegalArgumentException("A search definition with a search section called '" + rawName + @@ -247,8 +243,7 @@ public class SearchBuilder { DocumentModelBuilder builder = new DocumentModelBuilder(model); for (Search search : new SearchOrderer().order(searchList)) { - new FieldOperationApplierForSearch().process(search); - // These two needed for a couple of old unit tests, ideally these are just read from app + new FieldOperationApplierForSearch().process(search); // TODO: Why is this not in the regular list? process(search, deployLogger, new QueryProfiles(queryProfileRegistry), validate); built.add(search); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java index 9a00ee5bbd0..7c2d9a3b0ad 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java @@ -74,21 +74,11 @@ public class DerivedConfiguration { QueryProfileRegistry queryProfiles, ImportedModels importedModels) { Validator.ensureNotNull("Search definition", search); - if ( ! search.isProcessed()) { - throw new IllegalArgumentException("Search '" + search.getName() + "' not processed."); - } this.search = search; if ( ! search.isDocumentsOnly()) { streamingFields = new VsmFields(search); streamingSummary = new VsmSummary(search); } - if (abstractSearchList != null) { - for (Search abstractSearch : abstractSearchList) { - if (!abstractSearch.isProcessed()) { - throw new IllegalArgumentException("Search '" + search.getName() + "' not processed."); - } - } - } if ( ! search.isDocumentsOnly()) { attributeFields = new AttributeFields(search); summaries = new Summaries(search, deployLogger); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java index c041d5c6a89..a1b0e72051b 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java @@ -13,6 +13,7 @@ import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.SerializationContext; +import com.yahoo.tensor.TensorType; import com.yahoo.vespa.config.search.RankProfilesConfig; import java.nio.charset.Charset; @@ -23,6 +24,7 @@ import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; /** * A rank profile derived from a search definition, containing exactly the features available natively in the server @@ -180,32 +182,39 @@ public class RawRankProfile implements RankProfilesConfig.Producer { private void derivePropertiesAndSummaryFeaturesFromFunctions(Map<String, RankProfile.RankingExpressionFunction> functions) { if (functions.isEmpty()) return; - Map<String, ExpressionFunction> expressionFunctions = new LinkedHashMap<>(); - for (Map.Entry<String, RankProfile.RankingExpressionFunction> function : functions.entrySet()) { - expressionFunctions.put(function.getKey(), function.getValue().function()); - } + + List<ExpressionFunction> functionExpressions = functions.values().stream().map(f -> f.function()).collect(Collectors.toList()); Map<String, String> functionProperties = new LinkedHashMap<>(); - functionProperties.putAll(deriveFunctionProperties(expressionFunctions)); + functionProperties.putAll(deriveFunctionProperties(functions, functionExpressions)); + if (firstPhaseRanking != null) { - functionProperties.putAll(firstPhaseRanking.getRankProperties(new ArrayList<>(expressionFunctions.values()))); + functionProperties.putAll(firstPhaseRanking.getRankProperties(functionExpressions)); } if (secondPhaseRanking != null) { - functionProperties.putAll(secondPhaseRanking.getRankProperties(new ArrayList<>(expressionFunctions.values()))); + functionProperties.putAll(secondPhaseRanking.getRankProperties(functionExpressions)); } for (Map.Entry<String, String> e : functionProperties.entrySet()) { rankProperties.add(new RankProfile.RankProperty(e.getKey(), e.getValue())); } - SerializationContext context = new SerializationContext(expressionFunctions.values(), null, functionProperties); + SerializationContext context = new SerializationContext(functionExpressions, null, functionProperties); replaceFunctionSummaryFeatures(context); } - private Map<String, String> deriveFunctionProperties(Map<String, ExpressionFunction> functions) { - SerializationContext context = new SerializationContext(functions); - for (Map.Entry<String, ExpressionFunction> e : functions.entrySet()) { - String expression = e.getValue().getBody().getRoot().toString(new StringBuilder(), context, null, null).toString(); - context.addFunctionSerialization(RankingExpression.propertyName(e.getKey()), expression); + private Map<String, String> deriveFunctionProperties(Map<String, RankProfile.RankingExpressionFunction> functions, + List<ExpressionFunction> functionExpressions) { + SerializationContext context = new SerializationContext(functionExpressions); + for (Map.Entry<String, RankProfile.RankingExpressionFunction> e : functions.entrySet()) { + String expressionString = e.getValue().function().getBody().getRoot().toString(new StringBuilder(), context, null, null).toString(); + context.addFunctionSerialization(RankingExpression.propertyName(e.getKey()), expressionString); + + for (Map.Entry<String, TensorType> argumentType : e.getValue().function().argumentTypes().entrySet()) + context.addArgumentTypeSerialization(e.getKey(), argumentType.getKey(), argumentType.getValue()); + if (e.getValue().function().returnType().isPresent()) + context.addFunctionTypeSerialization(e.getKey(), e.getValue().function().returnType().get()); + // else if (e.getValue().function().arguments().isEmpty()) TODO: Enable this check when we resolve all types + // throw new IllegalStateException("Type of function '" + e.getKey() + "' is not resolved"); } return context.serializedFunctions(); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java index 8634d51c418..ab143f77b6a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java @@ -41,11 +41,11 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans if ( ! feature.getName().equals("onnx")) return feature; try { - // TODO: Put modelPath in FeatureArguments instead - Path modelPath = Path.fromString(FeatureArguments.asString(feature.getArguments().expressions().get(0))); + FeatureArguments arguments = asFeatureArguments(feature.getArguments()); ConvertedModel convertedModel = - convertedOnnxModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, true, context)); - return convertedModel.expression(asFeatureArguments(feature.getArguments()), context); + convertedOnnxModels.computeIfAbsent(arguments.path(), + path -> ConvertedModel.fromSourceOrStore(path, true, context)); + return convertedModel.expression(arguments, context); } catch (IllegalArgumentException | UncheckedIOException e) { throw new IllegalArgumentException("Could not use Onnx model from " + feature, e); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index 5139d041f00..4a315420b0a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -40,10 +40,11 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil if ( ! feature.getName().equals("tensorflow")) return feature; try { - Path modelPath = Path.fromString(FeatureArguments.asString(feature.getArguments().expressions().get(0))); + FeatureArguments arguments = asFeatureArguments(feature.getArguments()); ConvertedModel convertedModel = - convertedTensorFlowModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, false, context)); - return convertedModel.expression(asFeatureArguments(feature.getArguments()), context); + convertedTensorFlowModels.computeIfAbsent(arguments.path(), + path -> ConvertedModel.fromSourceOrStore(path, false, context)); + return convertedModel.expression(arguments, context); } catch (IllegalArgumentException | UncheckedIOException e) { throw new IllegalArgumentException("Could not use tensorflow model from " + feature, e); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java index f21248b6d74..663c5afbed6 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java @@ -41,10 +41,11 @@ public class XgboostFeatureConverter extends ExpressionTransformer<RankProfileTr if ( ! feature.getName().equals("xgboost")) return feature; try { - Path modelPath = Path.fromString(FeatureArguments.asString(feature.getArguments().expressions().get(0))); + FeatureArguments arguments = asFeatureArguments(feature.getArguments()); ConvertedModel convertedModel = - convertedXGBoostModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, true, context)); - return convertedModel.expression(asFeatureArguments(feature.getArguments()), context); + convertedXGBoostModels.computeIfAbsent(arguments.path(), + path -> ConvertedModel.fromSourceOrStore(path, true, context)); + return convertedModel.expression(arguments, context); } catch (IllegalArgumentException | UncheckedIOException e) { throw new IllegalArgumentException("Could not use XGBoost model from " + feature, e); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java index 3b2e29c4cb3..1af2a979cb4 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java @@ -7,10 +7,8 @@ import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.processing.multifieldresolver.RankProfileTypeSettingsProcessor; import com.yahoo.vespa.model.container.search.QueryProfiles; -import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; -import java.util.List; /** * Executor of processors. This defines the right order of processor execution. @@ -76,12 +74,20 @@ public class Processing { ReferenceFieldsProcessor::new, FastAccessValidator::new, ReservedFunctionNames::new, - RankingExpressionTypeValidator::new, + RankingExpressionTypeResolver::new, // These should be last: IndexingValidation::new, IndexingValues::new); } + /** Processors of rank profiles only (those who tolerate and so something useful when the search field is null) */ + private Collection<ProcessorFactory> rankProfileProcessors() { + return Arrays.asList( + RankProfileTypeSettingsProcessor::new, + ReservedFunctionNames::new, + RankingExpressionTypeResolver::new); + } + /** * Runs all search processors on the given {@link Search} object. These will modify the search object, <b>possibly * exchanging it with another</b>, as well as its document types. @@ -94,12 +100,26 @@ public class Processing { public void process(Search search, DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry, QueryProfiles queryProfiles, boolean validate, boolean documentsOnly) { Collection<ProcessorFactory> factories = processors(); - search.process(); factories.stream() .map(factory -> factory.create(search, deployLogger, rankProfileRegistry, queryProfiles)) .forEach(processor -> processor.process(validate, documentsOnly)); } + /** + * Runs rank profiles processors only. + * + * @param deployLogger The log to log messages and warnings for application deployment to + * @param rankProfileRegistry a {@link com.yahoo.searchdefinition.RankProfileRegistry} + * @param queryProfiles The query profiles contained in the application this search is part of. + */ + public void processRankProfiles(DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry, + QueryProfiles queryProfiles, boolean validate, boolean documentsOnly) { + Collection<ProcessorFactory> factories = rankProfileProcessors(); + factories.stream() + .map(factory -> factory.create(null, deployLogger, rankProfileRegistry, queryProfiles)) + .forEach(processor -> processor.process(validate, documentsOnly)); + } + @FunctionalInterface public interface ProcessorFactory { Processor create(Search search, DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry, QueryProfiles queryProfiles); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidator.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java index 102d1910360..4c8b5910b78 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidator.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java @@ -7,39 +7,44 @@ import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.Search; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.TypeContext; import com.yahoo.vespa.model.container.search.QueryProfiles; +import java.util.Map; + /** - * Validates the types of all ranking expressions under a search instance: + * Resolves and assigns types to all functions in a ranking expression, and + * validates the types of all ranking expressions under a search instance: * Some operators constrain the types of inputs, and first-and second-phase expressions - * must return scalar values. In addition, the existence of all referred attribute, query and constant + * must return scalar values. + * + * In addition, the existence of all referred attribute, query and constant * features is ensured. * * @author bratseth */ -public class RankingExpressionTypeValidator extends Processor { +public class RankingExpressionTypeResolver extends Processor { private final QueryProfileRegistry queryProfiles; - public RankingExpressionTypeValidator(Search search, - DeployLogger deployLogger, - RankProfileRegistry rankProfileRegistry, - QueryProfiles queryProfiles) { + public RankingExpressionTypeResolver(Search search, + DeployLogger deployLogger, + RankProfileRegistry rankProfileRegistry, + QueryProfiles queryProfiles) { super(search, deployLogger, rankProfileRegistry, queryProfiles); this.queryProfiles = queryProfiles.getRegistry(); } @Override public void process(boolean validate, boolean documentsOnly) { - if ( ! validate) return; if (documentsOnly) return; for (RankProfile profile : rankProfileRegistry.rankProfilesOf(search)) { try { - validate(profile); + resolveTypesIn(profile, validate); } catch (IllegalArgumentException e) { throw new IllegalArgumentException("In " + search + ", " + profile, e); @@ -47,20 +52,34 @@ public class RankingExpressionTypeValidator extends Processor { } } - /** Throws an IllegalArgumentException if the given rank profile does not produce valid type */ - private void validate(RankProfile profile) { - TypeContext context = profile.typeContext(queryProfiles); - profile.getSummaryFeatures().forEach(f -> ensureValid(f, "summary feature " + f, context)); - ensureValidDouble(profile.getFirstPhaseRanking(), "first-phase expression", context); - ensureValidDouble(profile.getSecondPhaseRanking(), "second-phase expression", context); + /** + * Resolves the types of all functions in the given profile + * + * @throws IllegalArgumentException if validate is true and the given rank profile does not produce valid types + */ + private void resolveTypesIn(RankProfile profile, boolean validate) { + TypeContext<Reference> context = profile.typeContext(queryProfiles); + for (Map.Entry<String, RankProfile.RankingExpressionFunction> function : profile.getFunctions().entrySet()) { + if ( ! function.getValue().function().arguments().isEmpty()) continue; + TensorType type = resolveType(function.getValue().function().getBody(), + "function '" + function.getKey() + "'", + context); + function.getValue().setReturnType(type); + } + + if (validate) { + profile.getSummaryFeatures().forEach(f -> resolveType(f, "summary feature " + f, context)); + ensureValidDouble(profile.getFirstPhaseRanking(), "first-phase expression", context); + ensureValidDouble(profile.getSecondPhaseRanking(), "second-phase expression", context); + } } - private TensorType ensureValid(RankingExpression expression, String expressionDescription, TypeContext context) { + private TensorType resolveType(RankingExpression expression, String expressionDescription, TypeContext context) { if (expression == null) return null; - return ensureValid(expression.getRoot(), expressionDescription, context); + return resolveType(expression.getRoot(), expressionDescription, context); } - private TensorType ensureValid(ExpressionNode expression, String expressionDescription, TypeContext context) { + private TensorType resolveType(ExpressionNode expression, String expressionDescription, TypeContext context) { TensorType type; try { type = expression.type(context); @@ -75,7 +94,7 @@ public class RankingExpressionTypeValidator extends Processor { private void ensureValidDouble(RankingExpression expression, String expressionDescription, TypeContext context) { if (expression == null) return; - TensorType type = ensureValid(expression, expressionDescription, context); + TensorType type = resolveType(expression, expressionDescription, context); if ( ! type.equals(TensorType.empty)) throw new IllegalArgumentException("The " + expressionDescription + " must produce a double " + "(a tensor with no dimensions), but produces " + type); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/multifieldresolver/RankProfileTypeSettingsProcessor.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/multifieldresolver/RankProfileTypeSettingsProcessor.java index ec4cbdfe58b..3bde76c1c79 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/multifieldresolver/RankProfileTypeSettingsProcessor.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/multifieldresolver/RankProfileTypeSettingsProcessor.java @@ -44,6 +44,7 @@ public class RankProfileTypeSettingsProcessor extends Processor { } private void processAttributeFields() { + if (search == null) return; // we're processing global profiles for (SDField field : search.allConcreteFields()) { Attribute attribute = field.getAttributes().get(field.getName()); if (attribute != null && attribute.tensorType().isPresent()) { @@ -53,6 +54,7 @@ public class RankProfileTypeSettingsProcessor extends Processor { } private void processImportedFields() { + if (search == null) return; // we're processing global profiles Optional<ImportedFields> importedFields = search.importedFields(); if (importedFields.isPresent()) { importedFields.get().fields().forEach((fieldName, field) -> processImportedField(field)); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/Service.java b/config-model/src/main/java/com/yahoo/vespa/model/Service.java index 620e44bc11a..29ec26b06d2 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/Service.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/Service.java @@ -7,7 +7,7 @@ import java.util.HashMap; import java.util.Optional; /** - * Representation of a process which runs a service + * Representation of a markProcessed which runs a service * * @author gjoranv */ diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java index 4b70b1b5ae2..d80799d4390 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java @@ -32,7 +32,9 @@ import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.RankingConstants; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RankProfileList; +import com.yahoo.searchdefinition.processing.Processing; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; +import com.yahoo.vespa.model.container.search.QueryProfiles; import com.yahoo.vespa.model.ml.ConvertedModel; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; @@ -168,7 +170,7 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri createGlobalRankProfiles(deployState.getImportedModels(), deployState.rankProfileRegistry(), - deployState.getQueryProfiles().getRegistry()); + deployState.getQueryProfiles()); this.rankProfileList = new RankProfileList(null, // null search -> global rankingConstants, AttributeFields.empty, @@ -219,41 +221,39 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri /** Adds generic application specific clusters of services */ private void addServiceClusters(ApplicationPackage app, VespaModelBuilder builder) { - for (ServiceCluster sc : builder.getClusters(app, this)) - serviceClusters.add(sc); + serviceClusters.addAll(builder.getClusters(app, this)); } /** - * Creates a rank profile not attached to any search definition, for each imported model in the application package + * Creates a rank profile not attached to any search definition, for each imported model in the application package, + * and adds it to the given rank profile registry. */ - private ImmutableList<RankProfile> createGlobalRankProfiles(ImportedModels importedModels, - RankProfileRegistry rankProfileRegistry, - QueryProfileRegistry queryProfiles) { - List<RankProfile> profiles = new ArrayList<>(); + private void createGlobalRankProfiles(ImportedModels importedModels, + RankProfileRegistry rankProfileRegistry, + QueryProfiles queryProfiles) { if ( ! importedModels.all().isEmpty()) { // models/ directory is available for (ImportedModel model : importedModels.all()) { RankProfile profile = new RankProfile(model.name(), this, rankProfileRegistry); rankProfileRegistry.add(profile); ConvertedModel convertedModel = ConvertedModel.fromSource(new ModelName(model.name()), - model.name(), profile, queryProfiles, model); - for (Map.Entry<String, RankingExpression> entry : convertedModel.expressions().entrySet()) { - profile.addFunction(new ExpressionFunction(entry.getKey(), entry.getValue()), false); - } + model.name(), profile, queryProfiles.getRegistry(), model); + convertedModel.expressions().values().forEach(f -> profile.addFunction(f, false)); } } else { // generated and stored model information may be available instead ApplicationFile generatedModelsDir = applicationPackage.getFile(ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR); for (ApplicationFile generatedModelDir : generatedModelsDir.listFiles()) { String modelName = generatedModelDir.getPath().last(); + if (modelName.contains(".")) continue; // Name space: Not a global profile RankProfile profile = new RankProfile(modelName, this, rankProfileRegistry); rankProfileRegistry.add(profile); ConvertedModel convertedModel = ConvertedModel.fromStore(new ModelName(modelName), modelName, profile); - for (Map.Entry<String, RankingExpression> entry : convertedModel.expressions().entrySet()) { - profile.addFunction(new ExpressionFunction(entry.getKey(), entry.getValue()), false); - } + convertedModel.expressions().values().forEach(f -> profile.addFunction(f, false)); } } - return ImmutableList.copyOf(profiles); + new Processing().processRankProfiles(deployState.getDeployLogger(), + rankProfileRegistry, + queryProfiles, true, false); } /** Returns the global rank profiles as a rank profile list */ diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index adf5c81283e..30586b1e677 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -48,6 +48,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -67,14 +68,14 @@ public class ConvertedModel { private final ModelName modelName; private final String modelDescription; - private final ImmutableMap<String, RankingExpression> expressions; + private final ImmutableMap<String, ExpressionFunction> expressions; /** The source importedModel, or empty if this was created from a stored converted model */ private final Optional<ImportedModel> sourceModel; private ConvertedModel(ModelName modelName, String modelDescription, - Map<String, RankingExpression> expressions, + Map<String, ExpressionFunction> expressions, Optional<ImportedModel> sourceModel) { this.modelName = modelName; this.modelDescription = modelDescription; @@ -90,18 +91,26 @@ public class ConvertedModel { * @param pathIsFile true if that path (this kind of model) is stored in a file, false if it is in a directory */ public static ConvertedModel fromSourceOrStore(Path modelPath, boolean pathIsFile, RankProfileTransformContext context) { - File sourceModel = sourceModelFile(context.rankProfile().applicationPackage(), modelPath); + ImportedModel sourceModel = // TODO: Convert to name here, make sure its done just one way + context.importedModels().get(sourceModelFile(context.rankProfile().applicationPackage(), modelPath)); ModelName modelName = new ModelName(context.rankProfile().getName(), modelPath, pathIsFile); - if (sourceModel.exists()) + + if (sourceModel == null && ! new ModelStore(context.rankProfile().applicationPackage(), modelName).exists()) + throw new IllegalArgumentException("No model '" + modelPath + "' is available. Available models: " + + context.importedModels().all().stream().map(ImportedModel::source).collect(Collectors.joining(", "))); + + if (sourceModel != null) { return fromSource(modelName, modelPath.toString(), context.rankProfile(), context.queryProfiles(), - context.importedModels().get(sourceModel)); // TODO: Convert to name here, make sure its done just one way - else + sourceModel); + } + else { return fromStore(modelName, modelPath.toString(), context.rankProfile()); + } } public static ConvertedModel fromSource(ModelName modelName, @@ -132,23 +141,23 @@ public class ConvertedModel { * if signatures are used, or the expression name if signatures are not used and there are multiple * expressions, and the second is the output name if signature names are used. */ - public Map<String, RankingExpression> expressions() { return expressions; } + public Map<String, ExpressionFunction> expressions() { return expressions; } /** * Returns the expression matching the given arguments. */ public ExpressionNode expression(FeatureArguments arguments, RankProfileTransformContext context) { - RankingExpression expression = selectExpression(arguments); - if (sourceModel.isPresent()) // we can verify - verifyRequiredFunctions(expression, sourceModel.get(), context.rankProfile(), context.queryProfiles()); - return expression.getRoot(); + ExpressionFunction expression = selectExpression(arguments); + if (sourceModel.isPresent()) // we should verify + verifyInputs(expression.getBody(), sourceModel.get(), context.rankProfile(), context.queryProfiles()); + return expression.getBody().getRoot(); } - private RankingExpression selectExpression(FeatureArguments arguments) { + private ExpressionFunction selectExpression(FeatureArguments arguments) { if (expressions.isEmpty()) throw new IllegalArgumentException("No expressions available in " + this); - RankingExpression expression = expressions.get(arguments.toName()); + ExpressionFunction expression = expressions.get(arguments.toName()); if (expression != null) return expression; if ( ! arguments.signature().isPresent()) { @@ -158,7 +167,7 @@ public class ConvertedModel { } if ( ! arguments.output().isPresent()) { - List<Map.Entry<String, RankingExpression>> entriesWithTheRightPrefix = + List<Map.Entry<String, ExpressionFunction>> entriesWithTheRightPrefix = expressions.entrySet().stream().filter(entry -> entry.getKey().startsWith(arguments.signature().get() + ".")).collect(Collectors.toList()); if (entriesWithTheRightPrefix.size() < 1) throw new IllegalArgumentException("No expressions named '" + arguments.signature().get() + @@ -179,10 +188,10 @@ public class ConvertedModel { // ----------------------- Static model conversion/storage below here - private static Map<String, RankingExpression> convertAndStore(ImportedModel model, - RankProfile profile, - QueryProfileRegistry queryProfiles, - ModelStore store) { + private static Map<String, ExpressionFunction> convertAndStore(ImportedModel model, + RankProfile profile, + QueryProfileRegistry queryProfiles, + ModelStore store) { // Add constants Set<String> constantsReplacedByFunctions = new HashSet<>(); model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v)); @@ -193,8 +202,8 @@ public class ConvertedModel { addGeneratedFunctions(model, profile); // Add expressions - Map<String, RankingExpression> expressions = new HashMap<>(); - for (Pair<String, RankingExpression> output : model.outputExpressions()) { + Map<String, ExpressionFunction> expressions = new HashMap<>(); + for (Pair<String, ExpressionFunction> output : model.outputExpressions()) { addExpression(output.getSecond(), output.getFirst(), constantsReplacedByFunctions, model, store, profile, queryProfiles, @@ -210,21 +219,21 @@ public class ConvertedModel { return expressions; } - private static void addExpression(RankingExpression expression, + private static void addExpression(ExpressionFunction expression, String expressionName, Set<String> constantsReplacedByFunctions, ImportedModel model, ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles, - Map<String, RankingExpression> expressions) { - expression = replaceConstantsByFunctions(expression, constantsReplacedByFunctions); - reduceBatchDimensions(expression, model, profile, queryProfiles); + Map<String, ExpressionFunction> expressions) { + expression = expression.withBody(replaceConstantsByFunctions(expression.getBody(), constantsReplacedByFunctions)); + reduceBatchDimensions(expression.getBody(), model, profile, queryProfiles); store.writeExpression(expressionName, expression); expressions.put(expressionName, expression); } - private static Map<String, RankingExpression> convertStored(ModelStore store, RankProfile profile) { + private static Map<String, ExpressionFunction> convertStored(ModelStore store, RankProfile profile) { for (Pair<String, Tensor> constant : store.readSmallConstants()) profile.addConstant(constant.getFirst(), asValue(constant.getSecond())); @@ -290,15 +299,15 @@ public class ConvertedModel { } /** - * Verify that the functions referred in the given expression exists in the given rank profile, - * and return tensors of the types specified in requiredFunctions. + * Verify that the inputs declared in the given expression exists in the given rank profile as functions, + * and return tensors of the correct types. */ - private static void verifyRequiredFunctions(RankingExpression expression, ImportedModel model, - RankProfile profile, QueryProfileRegistry queryProfiles) { + private static void verifyInputs(RankingExpression expression, ImportedModel model, + RankProfile profile, QueryProfileRegistry queryProfiles) { Set<String> functionNames = new HashSet<>(); addFunctionNamesIn(expression.getRoot(), functionNames, model); for (String functionName : functionNames) { - TensorType requiredType = model.requiredFunctions().get(functionName); + TensorType requiredType = model.inputs().get(functionName); if (requiredType == null) continue; // Not a required function RankProfile.RankingExpressionFunction rankingExpressionFunction = profile.getFunctions().get(functionName); @@ -375,7 +384,7 @@ public class ConvertedModel { List<ExpressionNode> children = ((TensorFunctionNode)node).children(); if (children.size() == 1 && children.get(0) instanceof ReferenceNode) { ReferenceNode referenceNode = (ReferenceNode) children.get(0); - if (model.requiredFunctions().containsKey(referenceNode.getName())) { + if (model.inputs().containsKey(referenceNode.getName())) { return reduceBatchDimensionExpression(tensorFunction, typeContext); } } @@ -383,7 +392,7 @@ public class ConvertedModel { } if (node instanceof ReferenceNode) { ReferenceNode referenceNode = (ReferenceNode) node; - if (model.requiredFunctions().containsKey(referenceNode.getName())) { + if (model.inputs().containsKey(referenceNode.getName())) { return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext); } } @@ -451,7 +460,8 @@ public class ConvertedModel { Set<String> constantsReplacedByFunctions) { if (constantsReplacedByFunctions.isEmpty()) return expression; return new RankingExpression(expression.getName(), - replaceConstantsByFunctions(expression.getRoot(), constantsReplacedByFunctions)); + replaceConstantsByFunctions(expression.getRoot(), + constantsReplacedByFunctions)); } private static ExpressionNode replaceConstantsByFunctions(ExpressionNode node, Set<String> constantsReplacedByFunctions) { @@ -518,25 +528,32 @@ public class ConvertedModel { this.modelFiles = new ModelFiles(modelName); } + /** Returns whether a model store for this application and model name exists */ + public boolean exists() { + return application.getFile(modelFiles.storedModelReplicatedPath()).exists(); + } + /** * Adds this expression to the application package, such that it can be read later. * * @param name the name of this ranking expression - may have 1-3 parts separated by dot where the first part * is always the model name */ - void writeExpression(String name, RankingExpression expression) { - application.getFile(modelFiles.expressionPath(name)) - .writeFile(new StringReader(expression.getRoot().toString())); + void writeExpression(String name, ExpressionFunction expression) { + StringBuilder b = new StringBuilder(expression.getBody().getRoot().toString()); + for (Map.Entry<String, TensorType> input : expression.argumentTypes().entrySet()) + b.append('\n').append(input.getKey()).append('\t').append(input.getValue()); + application.getFile(modelFiles.expressionPath(name)).writeFile(new StringReader(b.toString())); } - Map<String, RankingExpression> readExpressions() { - Map<String, RankingExpression> expressions = new HashMap<>(); + Map<String, ExpressionFunction> readExpressions() { + Map<String, ExpressionFunction> expressions = new HashMap<>(); ApplicationFile expressionPath = application.getFile(modelFiles.expressionsPath()); if ( ! expressionPath.exists() || ! expressionPath.isDirectory()) return Collections.emptyMap(); for (ApplicationFile expressionFile : expressionPath.listFiles()) { - try (Reader reader = new BufferedReader(expressionFile.createReader())){ + try (BufferedReader reader = new BufferedReader(expressionFile.createReader())){ String name = expressionFile.getPath().getName(); - expressions.put(name, new RankingExpression(name, reader)); + expressions.put(name, readExpression(name, reader)); } catch (IOException e) { throw new UncheckedIOException("Failed reading " + expressionFile.getPath(), e); @@ -548,8 +565,22 @@ public class ConvertedModel { return expressions; } + private ExpressionFunction readExpression(String name, BufferedReader reader) + throws IOException, ParseException { + // First line is expression + RankingExpression expression = new RankingExpression(name, reader.readLine()); + // Next lines are inputs on the format name\ttensorTypeSpec + Map<String, TensorType> inputs = new LinkedHashMap<>(); + String line; + while (null != (line = reader.readLine())) { + String[] parts = line.split("\t"); + inputs.put(parts[0], TensorType.fromSpec(parts[1])); + } + return new ExpressionFunction(name, new ArrayList<>(inputs.keySet()), expression, inputs, Optional.empty()); + } + /** Adds this function expression to the application package so it can be read later. */ - void writeFunction(String name, RankingExpression expression) { + public void writeFunction(String name, RankingExpression expression) { application.getFile(modelFiles.functionsPath()).appendFile(name + "\t" + expression.getRoot().toString() + "\n"); } @@ -561,20 +592,20 @@ public class ConvertedModel { if ( ! file.exists()) return Collections.emptyList(); List<Pair<String, RankingExpression>> functions = new ArrayList<>(); - BufferedReader reader = new BufferedReader(file.createReader()); - String line; - while (null != (line = reader.readLine())) { - String[] parts = line.split("\t"); - String name = parts[0]; - try { - RankingExpression expression = new RankingExpression(parts[0], parts[1]); - functions.add(new Pair<>(name, expression)); - } - catch (ParseException e) { - throw new IllegalStateException("Could not parse " + name, e); + try (BufferedReader reader = new BufferedReader(file.createReader())) { + String line; + while (null != (line = reader.readLine())) { + String[] parts = line.split("\t"); + String name = parts[0]; + try { + RankingExpression expression = new RankingExpression(parts[0], parts[1]); + functions.add(new Pair<>(name, expression)); + } catch (ParseException e) { + throw new IllegalStateException("Could not parse " + name, e); + } } + return functions; } - return functions; } catch (IOException e) { throw new UncheckedIOException(e); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/FeatureArguments.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/FeatureArguments.java index fda49af6178..4a02dc97d19 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/FeatureArguments.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/FeatureArguments.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.ml; +import com.yahoo.path.Path; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; @@ -14,19 +15,24 @@ import java.util.Optional; */ public class FeatureArguments { + private final Path path; + /** Optional arguments */ private final Optional<String> signature, output; public FeatureArguments(Arguments arguments) { - this(optionalArgument(1, arguments), + this(Path.fromString(argument(0, arguments)), + optionalArgument(1, arguments), optionalArgument(2, arguments)); } - public FeatureArguments(Optional<String> signature, Optional<String> output) { + private FeatureArguments(Path path, Optional<String> signature, Optional<String> output) { + this.path = path; this.signature = signature; this.output = output; } + public Path path() { return path; } public Optional<String> signature() { return signature; } public Optional<String> output() { return output; } @@ -35,13 +41,20 @@ public class FeatureArguments { (output.isPresent() ? "." + output.get() : ""); } + private static String argument(int argumentIndex, Arguments arguments) { + if (argumentIndex >= arguments.expressions().size()) + throw new IllegalArgumentException("Requires at least " + argumentIndex + + " arguments, but got just " + arguments.size()); + return asString(arguments.expressions().get(argumentIndex)); + } + private static Optional<String> optionalArgument(int argumentIndex, Arguments arguments) { if (argumentIndex >= arguments.expressions().size()) return Optional.empty(); return Optional.of(asString(arguments.expressions().get(argumentIndex))); } - public static String asString(ExpressionNode node) { + private static String asString(ExpressionNode node) { if ( ! (node instanceof ConstantNode)) throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node); return stripQuotes(((ConstantNode)node).sourceString()); |