diff options
author | Jon Bratseth <jonbratseth@yahoo.com> | 2018-02-20 16:58:07 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-02-20 16:58:07 +0100 |
commit | 7cbcd92168f36a63f0dade4acc5683e134e9ac48 (patch) | |
tree | a15aaa05bb4d0592d655cd5c8c57fe02f258a39e | |
parent | 3cb51aa803fef2ab0d622768ff623a80691d6811 (diff) | |
parent | bf9358e1c983ca3b2c4f9630873ed4e53634236f (diff) |
Merge pull request #5065 from vespa-engine/bratseth/typecheck-all-2
Bratseth/typecheck all 2
74 files changed, 1191 insertions, 470 deletions
diff --git a/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java b/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java index d6b916680d8..bd94f67e4a7 100644 --- a/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java +++ b/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java @@ -323,7 +323,7 @@ public class DeployState implements ConfigDefinitionStore { closeIgnoreException(reader.getReader()); } } - builder.build(logger, queryProfiles); + builder.build(logger); return SearchDocumentModel.fromBuilderAndNames(builder, names); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/FeatureNames.java b/config-model/src/main/java/com/yahoo/searchdefinition/FeatureNames.java index dd03cb8b2a7..dc59d9cb3e5 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/FeatureNames.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/FeatureNames.java @@ -5,11 +5,10 @@ */ package com.yahoo.searchdefinition; -import java.util.Arrays; -import java.util.List; +import com.yahoo.searchlib.rankingexpression.Reference; + import java.util.Optional; import java.util.regex.Pattern; -import java.util.stream.Collectors; /** * Utility methods for query, document and constant rank feature names @@ -20,85 +19,16 @@ public class FeatureNames { private static final Pattern identifierRegexp = Pattern.compile("[A-Za-z0-9_][A-Za-z0-9_-]*"); - /** - * <p>Returns the given query, document or constant feature in canonical form. - * A feature name consists of a feature type name (query, attribute or constant), - * followed by one argument enclosed in quotes. - * The argument may be an identifier or any string single or double quoted.</p> - * - * <p>Argument string values may not contain comma, single quote nor double quote characters.</p> - * - * <p><i>The canonical form use no quotes for arguments which are identifiers, and double quotes otherwise.</i></p> - * - * <p>Note that the above definition is not true for features in general, which accept any ranking expression - * as argument.</p> - * - * @throws IllegalArgumentException if the feature name is not valid - */ - // Note that this implementation is more general than what is described above: - // It accepts any number of arguments and an optional output - public static String canonicalize(String feature) { - return canonicalizeIfValid(feature).orElseThrow(() -> - new IllegalArgumentException("A feature name must be on the form query(name), attribute(name) or " + - "constant(name), but was '" + feature + "'" - )); - } - - /** - * Canonicalizes the given argument as in canonicalize, but returns empty instead of throwing an exception if - * the argument is not a valid feature - */ - public static Optional<String> canonicalizeIfValid(String feature) { - int startParenthesis = feature.indexOf('('); - if (startParenthesis < 0) - return Optional.empty(); - int endParenthesis = feature.lastIndexOf(')'); - String featureType = feature.substring(0, startParenthesis); - if ( ! ( featureType.equals("query") || featureType.equals("attribute") || featureType.equals("constant"))) - return Optional.empty(); - if (startParenthesis < 1) return Optional.of(feature); // No arguments - if (endParenthesis < startParenthesis) - return Optional.empty(); - String argumentString = feature.substring(startParenthesis + 1, endParenthesis); - List<String> canonicalizedArguments = - Arrays.stream(argumentString.split(",")) - .map(FeatureNames::canonicalizeArgument) - .collect(Collectors.toList()); - return Optional.of(featureType + "(" + - canonicalizedArguments.stream().collect(Collectors.joining(",")) + - feature.substring(endParenthesis)); - } - - /** Canomicalizes a single argument */ - private static String canonicalizeArgument(String argument) { - if (argument.startsWith("'")) { - if ( ! argument.endsWith("'")) - throw new IllegalArgumentException("Feature arguments starting by a single quote " + - "must end by a single quote, but was \"" + argument + "\""); - argument = argument.substring(1, argument.length() - 1); - } - if (argument.startsWith("\"")) { - if ( ! argument.endsWith("\"")) - throw new IllegalArgumentException("Feature arguments starting by a double quote " + - "must end by a double quote, but was '" + argument + "'"); - argument = argument.substring(1, argument.length() - 1); - } - if (identifierRegexp.matcher(argument).matches()) - return argument; - else - return "\"" + argument + "\""; - } - - public static String asConstantFeature(String constantName) { - return canonicalize("constant(\"" + constantName + "\")"); + public static Reference asConstantFeature(String constantName) { + return Reference.simple("constant", quoteIfNecessary(constantName)); } - public static String asAttributeFeature(String attributeName) { - return canonicalize("attribute(\"" + attributeName + "\")"); + public static Reference asAttributeFeature(String attributeName) { + return Reference.simple("attribute", quoteIfNecessary(attributeName)); } - public static String asQueryFeature(String propertyName) { - return canonicalize("query(\"" + propertyName + "\")"); + public static Reference asQueryFeature(String propertyName) { + return Reference.simple("query", quoteIfNecessary(propertyName)); } /** @@ -106,15 +36,21 @@ public class FeatureNames { * or empty if it is not a valid query, attribute or constant feature name */ public static Optional<String> argumentOf(String feature) { - return canonicalizeIfValid(feature).map(f -> { - int startParenthesis = f.indexOf("("); - int endParenthesis = f.indexOf(")"); - String possiblyQuotedArgument = f.substring(startParenthesis + 1, endParenthesis); - if (possiblyQuotedArgument.startsWith("\"")) - return possiblyQuotedArgument.substring(1, possiblyQuotedArgument.length() - 1); - else - return possiblyQuotedArgument; - }); + Optional<Reference> reference = Reference.simple(feature); + if ( ! reference.isPresent()) return Optional.empty(); + if ( ! ( reference.get().name().equals("attribute") || + reference.get().name().equals("constant") || + reference.get().name().equals("query"))) + return Optional.empty(); + + return Optional.of(reference.get().arguments().expressions().get(0).toString()); + } + + private static String quoteIfNecessary(String s) { + if (identifierRegexp.matcher(s).matches()) + return s; + else + return "\"" + s + "\""; } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java new file mode 100644 index 00000000000..fcae756eab3 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java @@ -0,0 +1,156 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchdefinition; + +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.rule.Arguments; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.FunctionReferenceContext; +import com.yahoo.searchlib.rankingexpression.rule.NameNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; + +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +/** + * A context which only contains type information. + * This returns empty tensor types (double) for unknown features which are not + * query, attribute or constant features, as we do not have information about which such + * features exist (but we know those that exist are doubles). + * + * @author bratseth + */ +public class MapEvaluationTypeContext extends FunctionReferenceContext implements TypeContext<Reference> { + + private final Map<Reference, TensorType> featureTypes = new HashMap<>(); + + public MapEvaluationTypeContext(Collection<ExpressionFunction> functions) { + super(functions); + } + + public MapEvaluationTypeContext(Map<String, ExpressionFunction> functions, + Map<String, String> bindings, + Map<Reference, TensorType> featureTypes) { + super(functions, bindings); + this.featureTypes.putAll(featureTypes); + } + + public void setType(Reference reference, TensorType type) { + featureTypes.put(reference, type); + } + + @Override + public TensorType getType(String reference) { + throw new UnsupportedOperationException("Not able to parse gereral references from string form"); + } + + @Override + public TensorType getType(Reference reference) { + Optional<String> binding = boundIdentifier(reference); + if (binding.isPresent()) { + try { + // This is not pretty, but changing to bind expressions rather + // than their string values requires deeper changes + return new RankingExpression(binding.get()).type(this); + } + catch (ParseException e) { + throw new IllegalArgumentException(e); + } + } + + if (isSimpleFeature(reference)) { + // The argument may be a local identifier bound to the actual value + String argument = simpleArgument(reference.arguments()).get(); + reference = Reference.simple(reference.name(), bindings.getOrDefault(argument, argument)); + return featureTypes.get(reference); + } + + Optional<ExpressionFunction> function = functionInvocation(reference); + if (function.isPresent()) { + return function.get().getBody().type(this.withBindings(bind(function.get().arguments(), reference.arguments()))); + } + + // We do not know what this is - since we do not have complete knowledge abut the match features + // in Java we must assume this is a match feature and return the double type - which is the type of all + // all match features + return TensorType.empty; + } + + /** + * Returns the binding if this reference is a simple identifier which is bound in this context. + * Returns empty otherwise. + */ + private Optional<String> boundIdentifier(Reference reference) { + if ( ! reference.arguments().isEmpty()) return Optional.empty(); + if ( reference.output() != null) return Optional.empty(); + return Optional.ofNullable(bindings.get(reference.name())); + } + + /** + * Return whether the reference (discarding the output) is a simple feature + * ("attribute(name)", "constant(name)" or "query(name)"). + * We disregard the output because all outputs under a simple feature have the same type. + */ + private boolean isSimpleFeature(Reference reference) { + Optional<String> argument = simpleArgument(reference.arguments()); + if ( ! argument.isPresent()) return false; + return reference.name().equals("attribute") || + reference.name().equals("constant") || + reference.name().equals("query"); + } + + /** + * If these arguments contains one simple argument string, it is returned. + * Otherwise null is returned. + */ + private Optional<String> simpleArgument(Arguments arguments) { + if (arguments.expressions().size() != 1) return Optional.empty(); + ExpressionNode argument = arguments.expressions().get(0); + + if ( ! (argument instanceof ReferenceNode)) return Optional.empty(); + ReferenceNode refArgument = (ReferenceNode)argument; + + if ( ! refArgument.reference().isIdentifier()) return Optional.empty(); + + return Optional.of(refArgument.getName()); + } + + private Optional<ExpressionFunction> functionInvocation(Reference reference) { + if (reference.output() != null) return Optional.empty(); + ExpressionFunction function = functions().get(reference.name()); + if (function == null) return Optional.empty(); + if (function.arguments().size() != reference.arguments().size()) return Optional.empty(); + return Optional.of(function); + } + + /** Binds the given list of formal arguments to their actual values */ + private Map<String, String> bind(List<String> formalArguments, + Arguments invocationArguments) { + Map<String, String> bindings = new HashMap<>(formalArguments.size()); + for (int i = 0; i < formalArguments.size(); i++) { + String identifier = invocationArguments.expressions().get(i).toString(); + identifier = super.bindings.getOrDefault(identifier, identifier); + bindings.put(formalArguments.get(i), identifier); + } + return bindings; + } + + public Map<Reference, TensorType> featureTypes() { + return Collections.unmodifiableMap(featureTypes); + } + + @Override + public MapEvaluationTypeContext withBindings(Map<String, String> bindings) { + if (bindings.isEmpty() && this.bindings.isEmpty()) return this; + return new MapEvaluationTypeContext(functions(), bindings, featureTypes); + } + +} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index bcbc7cc99e2..bd645422d50 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -2,15 +2,9 @@ package com.yahoo.searchdefinition; import com.yahoo.config.application.api.ApplicationPackage; -import com.yahoo.config.model.deploy.DeployState; -import com.yahoo.io.reader.NamedReader; -import com.yahoo.processing.request.CompoundName; -import com.yahoo.search.query.profile.QueryProfile; import com.yahoo.search.query.profile.QueryProfileRegistry; -import com.yahoo.search.query.profile.config.QueryProfileXMLReader; import com.yahoo.search.query.profile.types.FieldDescription; import com.yahoo.search.query.profile.types.QueryProfileType; -import com.yahoo.search.query.profile.types.TensorFieldType; import com.yahoo.search.query.ranking.Diversity; import com.yahoo.searchdefinition.document.SDField; import com.yahoo.searchdefinition.expressiontransforms.RankProfileTransformContext; @@ -18,8 +12,8 @@ import com.yahoo.searchdefinition.parser.ParseException; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.FeatureList; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.evaluation.TypeMapContext; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.tensor.TensorType; @@ -39,7 +33,9 @@ import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; +import java.util.stream.Collectors; /** * Represents a rank profile - a named set of ranking settings @@ -363,14 +359,14 @@ public class RankProfile implements Serializable, Cloneable { /** Returns a read-only view of the summary features to use in this profile. This is never null */ public Set<ReferenceNode> getSummaryFeatures() { - if (summaryFeatures!=null) return Collections.unmodifiableSet(summaryFeatures); - if (getInherited()!=null) return getInherited().getSummaryFeatures(); + if (summaryFeatures != null) return Collections.unmodifiableSet(summaryFeatures); + if (getInherited() != null) return getInherited().getSummaryFeatures(); return Collections.emptySet(); } public void addSummaryFeature(ReferenceNode feature) { - if (summaryFeatures==null) - summaryFeatures=new LinkedHashSet<>(); + if (summaryFeatures == null) + summaryFeatures = new LinkedHashSet<>(); summaryFeatures.add(feature); } @@ -585,8 +581,11 @@ public class RankProfile implements Serializable, Cloneable { } /** - * Will take the parser-set textual ranking expressions and turn into objects + * Will take the parser-set textual ranking expressions and turn into ranking expression objects, + * if not already done */ + // TODO: There doesn't appear to be any good reason to defer parsing of ranking expressions + // until this is called. Simplify by parsing them right away. public void parseExpressions() { try { parseRankingExpressions(); @@ -604,20 +603,23 @@ public class RankProfile implements Serializable, Cloneable { for (Map.Entry<String, Macro> e : getMacros().entrySet()) { String macroName = e.getKey(); Macro macro = e.getValue(); - RankingExpression expr = parseRankingExpression(macroName, macro.getTextualExpression()); - macro.setRankingExpression(expr); - macro.setTextualExpression(expr.getRoot().toString()); + if (macro.getRankingExpression() == null) { + RankingExpression expr = parseRankingExpression(macroName, macro.getTextualExpression()); + macro.setRankingExpression(expr); + macro.setTextualExpression(expr.getRoot().toString()); + } } } /** * Passes ranking expressions on to parser + * * @throws ParseException if either of the ranking expressions could not be parsed */ private void parseRankingExpressions() throws ParseException { - if (getFirstPhaseRankingString() != null) + if (getFirstPhaseRankingString() != null && firstPhaseRanking == null) setFirstPhaseRanking(parseRankingExpression("firstphase", getFirstPhaseRankingString())); - if (getSecondPhaseRankingString() != null) + if (getSecondPhaseRankingString() != null && secondPhaseRanking == null) setSecondPhaseRanking(parseRankingExpression("secondphase", getSecondPhaseRankingString())); } @@ -748,7 +750,9 @@ public class RankProfile implements Serializable, Cloneable { * referable from this rank profile. */ public TypeContext typeContext(QueryProfileRegistry queryProfiles) { - TypeMapContext context = new TypeMapContext(); + MapEvaluationTypeContext context = new MapEvaluationTypeContext(getMacros().values().stream() + .map(Macro::asExpressionFunction) + .collect(Collectors.toList())); // Add small constants getConstants().forEach((k, v) -> context.setType(FeatureNames.asConstantFeature(k), v.type())); @@ -764,15 +768,18 @@ public class RankProfile implements Serializable, Cloneable { for (QueryProfileType queryProfileType : queryProfiles.getTypeRegistry().allComponents()) { for (FieldDescription field : queryProfileType.declaredFields().values()) { TensorType type = field.getType().asTensorType(); - String feature = FeatureNames.asQueryFeature(field.getName()); - TensorType existingType = context.getType(feature); + Optional<Reference> feature = Reference.simple(field.getName()); + if ( ! feature.isPresent() || ! feature.get().name().equals("query")) continue; + + TensorType existingType = context.getType(feature.get()); if (existingType != null) type = existingType.dimensionwiseGeneralizationWith(type).orElseThrow( () -> new IllegalArgumentException(queryProfileType + " contains query feature " + feature + " with type " + field.getType().asTensorType() + ", but this is already defined " + - "in another query profile with type " + context.getType(feature))); - context.setType(feature, type); + "in another query profile with type " + + context.getType(feature.get()))); + context.setType(feature.get(), type); } } @@ -910,7 +917,7 @@ public class RankProfile implements Serializable, Cloneable { */ public static class Macro implements Serializable, Cloneable { - private String name=null; + private final String name; private String textualExpression=null; private RankingExpression expression=null; private List<String> formalParams = new ArrayList<>(); @@ -955,7 +962,7 @@ public class RankProfile implements Serializable, Cloneable { return inline && formalParams.size() == 0; // only inline no-arg macros; } - public ExpressionFunction toExpressionMacro() { + public ExpressionFunction asExpressionFunction() { return new ExpressionFunction(getName(), getFormalParams(), getRankingExpression()); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java b/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java index 762c0fec838..e7cd21ac834 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java @@ -18,6 +18,7 @@ import com.yahoo.searchdefinition.parser.TokenMgrError; import com.yahoo.searchdefinition.processing.Processing; import com.yahoo.vespa.documentmodel.DocumentModel; import com.yahoo.vespa.model.container.search.QueryProfiles; +import com.yahoo.yolean.Exceptions; import java.io.File; import java.io.IOException; @@ -34,7 +35,6 @@ import java.util.List; * expressions, using the setRankXXX() methods, 3) invoke the {@link #build()} method, and 4) retrieve the built * search objects using the {@link #getSearch(String)} method. */ -// TODO: This should be cleaned up and more or maybe completely taken over by MockApplicationPackage public class SearchBuilder { private final DocumentTypeManager docTypeMgr = new DocumentTypeManager(); @@ -154,7 +154,7 @@ public class SearchBuilder { } catch (TokenMgrError e) { throw new ParseException("Unknown symbol: " + e.getMessage()); } catch (ParseException pe) { - throw new ParseException(stream.formatException(pe.getMessage())); + throw new ParseException(stream.formatException(Exceptions.toMessageString(pe))); } return importRawSearch(search); } @@ -196,11 +196,7 @@ public class SearchBuilder { * @throws IllegalStateException Thrown if this method has already been called. */ public void build() { - build(new BaseDeployLogger(), new QueryProfiles()); - } - - public void build(DeployLogger logger) { - build(logger, new QueryProfiles()); + build(new BaseDeployLogger()); } /** @@ -209,12 +205,10 @@ public class SearchBuilder { * * @throws IllegalStateException Thrown if this method has already been called. * @param deployLogger The logger to use during build - * @param queryProfiles The query profiles contained in the application this search is part of. */ - public void build(DeployLogger deployLogger, QueryProfiles queryProfiles) { - if (isBuilt) { - throw new IllegalStateException("Searches already built."); - } + public void build(DeployLogger deployLogger) { + if (isBuilt) throw new IllegalStateException("Model already built"); + List<Search> built = new ArrayList<>(); List<SDDocumentType> sdocs = new ArrayList<>(); sdocs.add(SDDocumentType.VESPA_DOCUMENT); @@ -240,7 +234,7 @@ public class SearchBuilder { for (Search search : new SearchOrderer().order(searchList)) { new FieldOperationApplierForSearch().process(search); // These two needed for a couple of old unit tests, ideally these are just read from app - process(search, deployLogger, queryProfiles); + process(search, deployLogger, new QueryProfiles(queryProfileRegistry)); built.add(search); } builder.addToModel(searchList); @@ -254,8 +248,6 @@ public class SearchBuilder { /** * Processes and returns the given {@link Search} object. This method has been factored out of the {@link * #build()} method so that subclasses can choose not to build anything. - * - * @param search The object to build. */ protected void process(Search search, DeployLogger deployLogger, QueryProfiles queryProfiles) { Processing.process(search, deployLogger, rankProfileRegistry, queryProfiles); @@ -352,7 +344,7 @@ public class SearchBuilder { rankProfileRegistry, queryprofileRegistry); builder.importFile(fileName); - builder.build(deployLogger, new QueryProfiles()); + builder.build(deployLogger); return builder; } @@ -368,7 +360,7 @@ public class SearchBuilder { for (Iterator<Path> i = Files.list(new File(dir).toPath()).filter(p -> p.getFileName().toString().endsWith(".sd")).iterator(); i.hasNext(); ) { builder.importFile(i.next()); } - builder.build(new BaseDeployLogger(), new QueryProfiles()); + builder.build(new BaseDeployLogger()); return builder; } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/TypeMapContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/TypeMapContext.java deleted file mode 100644 index 40e9db1413f..00000000000 --- a/config-model/src/main/java/com/yahoo/searchdefinition/TypeMapContext.java +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchdefinition; - -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.evaluation.TypeContext; - -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; - -/** - * A context which only contains type information. - * - * @author bratseth - */ -public class TypeMapContext implements TypeContext { - - private final Map<String, TensorType> featureTypes = new HashMap<>(); - - public void setType(String name, TensorType type) { - featureTypes.put(FeatureNames.canonicalize(name), type); - } - - @Override - public TensorType getType(String name) { - return featureTypes.get(FeatureNames.canonicalize(name)); - } - - /** Returns an unmodifiable map of the bindings in this */ - public Map<String, TensorType> bindings() { return Collections.unmodifiableMap(featureTypes); } - -} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java index ea02f960800..b02362154d9 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java @@ -188,7 +188,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { if (macros.isEmpty()) return; Map<String, ExpressionFunction> expressionMacros = new LinkedHashMap<>(); for (Map.Entry<String, RankProfile.Macro> macro : macros.entrySet()) { - expressionMacros.put(macro.getKey(), macro.getValue().toExpressionMacro()); + expressionMacros.put(macro.getKey(), macro.getValue().asExpressionFunction()); } Map<String, String> macroProperties = new LinkedHashMap<>(); @@ -223,7 +223,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { // Is the feature a macro? if (context.getFunction(referenceNode.getName()) != null) { context.addFunctionSerialization(RankingExpression.propertyName(referenceNode.getName()), - referenceNode.toString(context, null, null)); + referenceNode.toString(context, null, null)); ReferenceNode newReferenceNode = new ReferenceNode("rankingExpression(" + referenceNode.getName() + ")", referenceNode.getArguments().expressions(), referenceNode.getOutput()); macroSummaryFeatures.put(referenceNode.getName(), newReferenceNode); i.remove(); // Will add the expanded one in next block diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index 2b997aa25f2..f16697b5ba6 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -208,6 +208,10 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil throw new IllegalArgumentException("Model refers Placeholder '" + macroName + "' of type " + requiredType + " but this macro is not present in " + profile); + // TODO: We should verify this in the (function reference(s) this is invoked (starting from first/second + // phase and summary features), as it may only resolve correctly given those bindings + // Or, probably better, annotate the macros with type constraints here and verify during general + // type verification TensorType actualType = macro.getRankingExpression().getRoot().type(profile.typeContext(queryProfiles)); if ( actualType == null) throw new IllegalArgumentException("Model refers Placeholder '" + macroName + diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/IndexingValues.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/IndexingValues.java index ee65c9bec02..cc634abef01 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/IndexingValues.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/IndexingValues.java @@ -13,7 +13,7 @@ import com.yahoo.vespa.indexinglanguage.expressions.OutputExpression; import com.yahoo.vespa.model.container.search.QueryProfiles; /** - * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen Hult</a> + * @author Simon Thoresen Hult */ public class IndexingValues extends Processor { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java index 90183848094..061a803cb48 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java @@ -76,8 +76,9 @@ public class Processing { ImportedFieldsInSummayValidator::new, FastAccessValidator::new, ReservedMacroNames::new, + RankingExpressionTypeValidator::new, - // These two should be last. + // These should be last. IndexingValidation::new, IndexingValues::new); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidator.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidator.java new file mode 100644 index 00000000000..fa2610d77a1 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidator.java @@ -0,0 +1,82 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchdefinition.processing; + +import com.yahoo.config.application.api.DeployLogger; +import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchdefinition.RankProfileRegistry; +import com.yahoo.searchdefinition.Search; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; +import com.yahoo.vespa.model.container.search.QueryProfiles; + +/** + * Validates the types of all ranking expressions under a search instance: + * Some operators constrain the types of inputs, and first-and second-phase expressions + * must return scalar values. In addition, the existence of all referred attribute, query and constant + * features is ensured. + * + * @author bratseth + */ +public class RankingExpressionTypeValidator extends Processor { + + private final QueryProfileRegistry queryProfiles; + + public RankingExpressionTypeValidator(Search search, + DeployLogger deployLogger, + RankProfileRegistry rankProfileRegistry, + QueryProfiles queryProfiles) { + super(search, deployLogger, rankProfileRegistry, queryProfiles); + this.queryProfiles = queryProfiles.getRegistry(); + } + + @Override + public void process() { + for (RankProfile profile : rankProfileRegistry.allRankProfiles()) { + try { + validate(profile); + } + catch (IllegalArgumentException e) { + throw new IllegalArgumentException("In " + search + ", " + profile, e); + } + } + } + + /** Throws an IllegalArgumentException if the given rank profile does not produce valid type */ + private void validate(RankProfile profile) { + profile.parseExpressions(); + TypeContext context = profile.typeContext(queryProfiles); + profile.getSummaryFeatures().forEach(f -> ensureValid(f, "summary feature " + f, context)); + ensureValidDouble(profile.getFirstPhaseRanking(), "first-phase expression", context); + ensureValidDouble(profile.getSecondPhaseRanking(), "second-phase expression", context); + } + + private TensorType ensureValid(RankingExpression expression, String expressionDescription, TypeContext context) { + if (expression == null) return null; + return ensureValid(expression.getRoot(), expressionDescription, context); + } + + private TensorType ensureValid(ExpressionNode expression, String expressionDescription, TypeContext context) { + TensorType type; + try { + type = expression.type(context); + } + catch (IllegalArgumentException e) { + throw new IllegalArgumentException("The " + expressionDescription + " is invalid", e); + } + if (type == null) // Not expected to happen + throw new IllegalStateException("Could not determine the type produced by " + expressionDescription); + return type; + } + + private void ensureValidDouble(RankingExpression expression, String expressionDescription, TypeContext context) { + if (expression == null) return; + TensorType type = ensureValid(expression, expressionDescription, context); + if ( ! type.equals(TensorType.empty)) + throw new IllegalArgumentException("The " + expressionDescription + " must produce a double " + + "(a tensor with no dimensions), but produces " + type); + } + +} diff --git a/config-model/src/test/derived/rankexpression/rank-profiles.cfg b/config-model/src/test/derived/rankexpression/rank-profiles.cfg index e890b75770b..f5652c31d2a 100644 --- a/config-model/src/test/derived/rankexpression/rank-profiles.cfg +++ b/config-model/src/test/derived/rankexpression/rank-profiles.cfg @@ -24,7 +24,7 @@ rankprofile[0].fef.property[10].value "4" rankprofile[0].fef.property[11].name "vespa.dump.feature" rankprofile[0].fef.property[11].value "attribute(foo1).out" rankprofile[0].fef.property[12].name "vespa.dump.feature" -rankprofile[0].fef.property[12].value "attribute(bar1.out)" +rankprofile[0].fef.property[12].value "attribute(bar1)" rankprofile[0].fef.property[13].name "vespa.dump.feature" rankprofile[0].fef.property[13].value "attribute(foo2).out" rankprofile[0].fef.property[14].name "vespa.dump.feature" @@ -64,7 +64,7 @@ rankprofile[2].fef.property[2].value "10 + feature(arg1).out.out" rankprofile[2].fef.property[3].name "vespa.summary.feature" rankprofile[2].fef.property[3].value "attribute(foo1).out" rankprofile[2].fef.property[4].name "vespa.summary.feature" -rankprofile[2].fef.property[4].value "attribute(bar1.out)" +rankprofile[2].fef.property[4].value "attribute(bar1)" rankprofile[2].fef.property[5].name "vespa.summary.feature" rankprofile[2].fef.property[5].value "attribute(foo2).out" rankprofile[2].fef.property[6].name "vespa.summary.feature" diff --git a/config-model/src/test/derived/rankexpression/rankexpression.sd b/config-model/src/test/derived/rankexpression/rankexpression.sd index 8ed1f2bab4c..d3e0057cfe1 100644 --- a/config-model/src/test/derived/rankexpression/rankexpression.sd +++ b/config-model/src/test/derived/rankexpression/rankexpression.sd @@ -5,12 +5,10 @@ search rankexpression { field artist type string { indexing: summary | index - # index-to: artist, default } field title type string { indexing: summary | index - # index-to: title, default } field surl type string { @@ -21,6 +19,38 @@ search rankexpression { indexing: summary | attribute } + field foo1 type int { + indexing: attribute + } + + field foo2 type int { + indexing: attribute + } + + field foo3 type int { + indexing: attribute + } + + field foo4 type int { + indexing: attribute + } + + field bar1 type int { + indexing: attribute + } + + field bar2 type int { + indexing: attribute + } + + field bar3 type int { + indexing: attribute + } + + field bar4 type int { + indexing: attribute + } + } rank-profile default { @@ -33,7 +63,7 @@ search rankexpression { expression: if(3>2,4,2) rerank-count: 10 } - rank-features: attribute(foo1).out attribute(bar1.out) + rank-features: attribute(foo1).out attribute(bar1) rank-features { attribute(foo2).out attribute(bar2).out } rank-features { attribute(foo3).out attribute(bar3).out } @@ -65,7 +95,7 @@ search rankexpression { file:rankexpression } } - summary-features: attribute(foo1).out attribute(bar1.out) + summary-features: attribute(foo1).out attribute(bar1) summary-features { attribute(foo2).out attribute(bar2).out } summary-features { attribute(foo3).out attribute(bar3).out } diff --git a/config-model/src/test/derived/rankexpression/summary.cfg b/config-model/src/test/derived/rankexpression/summary.cfg index 00df2e87144..9752a9f55e3 100644 --- a/config-model/src/test/derived/rankexpression/summary.cfg +++ b/config-model/src/test/derived/rankexpression/summary.cfg @@ -15,9 +15,25 @@ classes[0].fields[5].name "summaryfeatures" classes[0].fields[5].type "featuredata" classes[0].fields[6].name "documentid" classes[0].fields[6].type "longstring" -classes[1].id 1787488393 +classes[1].id 1736696699 classes[1].name "attributeprefetch" classes[1].fields[0].name "year" +classes[].fields[].type "integer" +classes[].fields[].name "foo1" +classes[].fields[].type "integer" +classes[].fields[].name "foo2" +classes[].fields[].type "integer" +classes[].fields[].name "foo3" +classes[].fields[].type "integer" +classes[].fields[].name "foo4" +classes[].fields[].type "integer" +classes[].fields[].name "bar1" +classes[].fields[].type "integer" +classes[].fields[].name "bar2" +classes[].fields[].type "integer" +classes[].fields[].name "bar3" +classes[].fields[].type "integer" +classes[].fields[].name "bar4" classes[1].fields[0].type "integer" classes[1].fields[1].name "rankfeatures" classes[1].fields[1].type "featuredata" diff --git a/config-model/src/test/derived/rankexpression/summarymap.cfg b/config-model/src/test/derived/rankexpression/summarymap.cfg index c810f7282ba..21e6cdf346f 100644 --- a/config-model/src/test/derived/rankexpression/summarymap.cfg +++ b/config-model/src/test/derived/rankexpression/summarymap.cfg @@ -7,4 +7,28 @@ override[1].command "rankfeatures" override[1].arguments "" override[2].field "summaryfeatures" override[2].command "summaryfeatures" -override[2].arguments ""
\ No newline at end of file +override[2].arguments "" +override[].field "foo1" +override[].command "attribute" +override[].arguments "foo1" +override[].field "foo2" +override[].command "attribute" +override[].arguments "foo2" +override[].field "foo3" +override[].command "attribute" +override[].arguments "foo3" +override[].field "foo4" +override[].command "attribute" +override[].arguments "foo4" +override[].field "bar1" +override[].command "attribute" +override[].arguments "bar1" +override[].field "bar2" +override[].command "attribute" +override[].arguments "bar2" +override[].field "bar3" +override[].command "attribute" +override[].arguments "bar3" +override[].field "bar4" +override[].command "attribute" +override[].arguments "bar4"
\ No newline at end of file diff --git a/config-model/src/test/derived/tensor/rank-profiles.cfg b/config-model/src/test/derived/tensor/rank-profiles.cfg index 2b231e0cda2..b6ad5372c05 100644 --- a/config-model/src/test/derived/tensor/rank-profiles.cfg +++ b/config-model/src/test/derived/tensor/rank-profiles.cfg @@ -35,7 +35,7 @@ rankprofile[3].name "profile2" rankprofile[3].fef.property[0].name "vespa.rank.firstphase" rankprofile[3].fef.property[0].value "rankingExpression(firstphase)" rankprofile[3].fef.property[1].name "rankingExpression(firstphase).rankingScript" -rankprofile[3].fef.property[1].value "reduce(join(attribute(f4), tensor(x[2],y[2],z[3])((x==y)*(y==z)), f(a,b)(a * b)), sum, x)" +rankprofile[3].fef.property[1].value "reduce(reduce(join(attribute(f4), tensor(x[2],y[2],z[3])((x==y)*(y==z)), f(a,b)(a * b)), sum, x), sum)" rankprofile[3].fef.property[2].name "vespa.type.attribute.f2" rankprofile[3].fef.property[2].value "tensor(x[2],y[])" rankprofile[3].fef.property[3].name "vespa.type.attribute.f3" diff --git a/config-model/src/test/derived/tensor/tensor.sd b/config-model/src/test/derived/tensor/tensor.sd index a6a9a98db3a..3d64f6b807e 100644 --- a/config-model/src/test/derived/tensor/tensor.sd +++ b/config-model/src/test/derived/tensor/tensor.sd @@ -28,7 +28,7 @@ search tensor { rank-profile profile2 { first-phase { - expression: matmul(attribute(f4), diag(x[2],y[2],z[3]), x) + expression: sum(matmul(attribute(f4), diag(x[2],y[2],z[3]), x)) } } diff --git a/config-model/src/test/examples/rankpropvars.sd b/config-model/src/test/examples/rankpropvars.sd index 40f9e73f35a..28959edbc09 100644 --- a/config-model/src/test/examples/rankpropvars.sd +++ b/config-model/src/test/examples/rankpropvars.sd @@ -18,8 +18,8 @@ first-phase { second-phase { expression { if (attribute(artist) == query(testvar1), - 0.0 * fieldMatch(title) + 0.0 * attribute(popularity) + 0.0 * fieldMatch(artist), - 0.0 * attribute(popularity) + 0.0 * fieldMatch(artist) + 0.0 * fieldMatch(title)) + 0.0 * fieldMatch(title) + 0.0 * attribute(Popularity) + 0.0 * fieldMatch(artist), + 0.0 * attribute(Popularity) + 0.0 * fieldMatch(artist) + 0.0 * fieldMatch(title)) } } @@ -42,8 +42,8 @@ first-phase { second-phase { expression { if (attribute(artist) == query(testvar1), - 0.0 * fieldMatch(title) + 0.0 * attribute(popularity) + 0.0 * fieldMatch(artist), - 0.0 * attribute(popularity) + 0.0 * fieldMatch(artist) + 0.0 * fieldMatch(title)) + 0.0 * fieldMatch(title) + 0.0 * attribute(Popularity) + 0.0 * fieldMatch(artist), + 0.0 * attribute(Popularity) + 0.0 * fieldMatch(artist) + 0.0 * fieldMatch(title)) } } } diff --git a/config-model/src/test/examples/simple.sd b/config-model/src/test/examples/simple.sd index 4fda7f5039e..96b0fa98098 100644 --- a/config-model/src/test/examples/simple.sd +++ b/config-model/src/test/examples/simple.sd @@ -116,7 +116,7 @@ search simple { first-phase { keep-rank-count:200 rank-score-drop-limit: -13.0 - expression: attribute(year) + expression: attribute(popularity) } second-phase { rerank-count: 99 diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/FeatureNamesTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/FeatureNamesTestCase.java index 1f60ad870ec..aa01070d296 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/FeatureNamesTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/FeatureNamesTestCase.java @@ -18,17 +18,6 @@ import static org.junit.Assert.assertFalse; public class FeatureNamesTestCase { @Test - public void testCanonicalization() { - assertFalse(FeatureNames.canonicalizeIfValid("foo").isPresent()); - assertEquals("query(bar)", FeatureNames.canonicalize("query(bar)")); - assertEquals("query(bar)", FeatureNames.canonicalize("query('bar')")); - assertEquals("constant(bar)", FeatureNames.canonicalize("constant(\"bar\")")); - assertEquals("query(\"ba.r\")", FeatureNames.canonicalize("query(ba.r)")); - assertEquals("query(\"ba.r\")", FeatureNames.canonicalize("query('ba.r')")); - assertEquals("attribute(\"ba.r\")", FeatureNames.canonicalize("attribute(\"ba.r\")")); - } - - @Test public void testArgument() { assertFalse(FeatureNames.argumentOf("foo(bar)").isPresent()); assertFalse(FeatureNames.argumentOf("foo(bar.baz)").isPresent()); @@ -42,17 +31,20 @@ public class FeatureNamesTestCase { @Test public void testConstantFeature() { - assertEquals("constant(\"foo/bar\")", FeatureNames.asConstantFeature("foo/bar")); + assertEquals("constant(\"foo/bar\")", + FeatureNames.asConstantFeature("foo/bar").toString()); } @Test public void testAttributeFeature() { - assertEquals("attribute(foo)", FeatureNames.asAttributeFeature("foo")); + assertEquals("attribute(foo)", + FeatureNames.asAttributeFeature("foo").toString()); } @Test public void testQueryFeature() { - assertEquals("query(\"foo.bar\")", FeatureNames.asQueryFeature("foo.bar")); + assertEquals("query(\"foo.bar\")", + FeatureNames.asQueryFeature("foo.bar").toString()); } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java index 442c8bd41bd..11093d9f008 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java @@ -135,13 +135,13 @@ public class RankProfileTestCase extends SearchDefinitionTestCase { @Test public void requireThatConfigIsDerivedForQueryFeatureTypeSettings() throws ParseException { RankProfileRegistry registry = new RankProfileRegistry(); - SearchBuilder builder = new SearchBuilder(registry); + SearchBuilder builder = new SearchBuilder(registry, setupQueryProfileTypes()); builder.importString("search test {\n" + " document test { } \n" + " rank-profile p1 {}\n" + " rank-profile p2 {}\n" + "}"); - builder.build(new BaseDeployLogger(), setupQueryProfileTypes()); + builder.build(new BaseDeployLogger()); Search search = builder.getSearch(); assertEquals(4, registry.allRankProfiles().size()); @@ -151,7 +151,7 @@ public class RankProfileTestCase extends SearchDefinitionTestCase { assertQueryFeatureTypeSettings(registry.getRankProfile(search, "p2"), search); } - private static QueryProfiles setupQueryProfileTypes() { + private static QueryProfileRegistry setupQueryProfileTypes() { QueryProfileRegistry registry = new QueryProfileRegistry(); QueryProfileTypeRegistry typeRegistry = registry.getTypeRegistry(); QueryProfileType type = new QueryProfileType(new ComponentId("testtype")); @@ -164,7 +164,7 @@ public class RankProfileTestCase extends SearchDefinitionTestCase { type.addField(new FieldDescription("ranking.features.query(numeric)", FieldType.fromString("integer", typeRegistry)), typeRegistry); typeRegistry.register(type); - return new QueryProfiles(registry); + return registry; } private static void assertQueryFeatureTypeSettings(RankProfile profile, Search search) { diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java index e94880e61c7..82b9f5ac043 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java @@ -207,6 +207,9 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase builder.importString( "search test {\n" + " document test { \n" + + " field rating_yelp type int {" + + " indexing: attribute" + + " }" + " }\n" + " \n" + " rank-profile test {\n" + diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java index 5100ac15c40..ed1b00e2875 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java @@ -2,7 +2,10 @@ package com.yahoo.searchdefinition; import com.yahoo.collections.Pair; +import com.yahoo.search.query.profile.QueryProfile; import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.search.query.profile.types.FieldDescription; +import com.yahoo.search.query.profile.types.QueryProfileType; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; @@ -149,11 +152,12 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase censorBindingHash(testRankProperties.get(4).toString())); } - @Test public void testNeuralNetworkSetup() throws ParseException { + // Note: the type assigned to query profile and constant tensors here is not the correct type RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); - SearchBuilder builder = new SearchBuilder(rankProfileRegistry); + QueryProfileRegistry queryProfiles = queryProfileWith("query(q)", "tensor(x[])"); + SearchBuilder builder = new SearchBuilder(rankProfileRegistry, queryProfiles); builder.importString( "search test {\n" + " document test { \n" + @@ -176,13 +180,28 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase " expression: sum(final_layer)\n" + " }\n" + " }\n" + - "\n" + + " constant W_hidden {\n" + + " type: tensor(x[])\n" + + " file: ignored.json\n" + + " }\n" + + " constant b_input {\n" + + " type: tensor(x[])\n" + + " file: ignored.json\n" + + " }\n" + + " constant W_final {\n" + + " type: tensor(x[])\n" + + " file: ignored.json\n" + + " }\n" + + " constant b_final {\n" + + " type: tensor(x[])\n" + + " file: ignored.json\n" + + " }\n" + "}\n"); builder.build(); Search s = builder.getSearch(); - RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry()); + RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(queryProfiles); List<Pair<String, String>> testRankProperties = new RawRankProfile(test, - new QueryProfileRegistry(), + queryProfiles, new AttributeFields(s)).configProperties(); assertEquals("(rankingExpression(relu).rankingScript,max(1.0,x))", testRankProperties.get(0).toString()); @@ -198,6 +217,17 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase testRankProperties.get(5).toString()); } + private QueryProfileRegistry queryProfileWith(String field, String type) { + QueryProfileType queryProfileType = new QueryProfileType("root"); + queryProfileType.addField(new FieldDescription(field, type)); + QueryProfileRegistry queryProfileRegistry = new QueryProfileRegistry(); + queryProfileRegistry.getTypeRegistry().register(queryProfileType); + QueryProfile profile = new QueryProfile("default"); + profile.setType(queryProfileType); + queryProfileRegistry.register(profile); + return queryProfileRegistry; + } + private String censorBindingHash(String s) { StringBuilder b = new StringBuilder(); boolean areInHash = false; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java index 800697b3430..0ce6129ef7f 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java @@ -38,7 +38,8 @@ class RankProfileSearchFixture { RankProfileSearchFixture(ApplicationPackage applicationpackage, QueryProfileRegistry queryProfileRegistry, String rankProfiles, String constant, String field) throws ParseException { - SearchBuilder builder = new SearchBuilder(applicationpackage, rankProfileRegistry, new QueryProfileRegistry()); + this.queryProfileRegistry = queryProfileRegistry; + SearchBuilder builder = new SearchBuilder(applicationpackage, rankProfileRegistry, queryProfileRegistry); String sdContent = "search test {\n" + " " + (constant != null ? constant : "") + "\n" + " document test {\n" + @@ -50,7 +51,6 @@ class RankProfileSearchFixture { builder.importString(sdContent); builder.build(); search = builder.getSearch(); - this.queryProfileRegistry = queryProfileRegistry; } public void assertFirstPhaseExpression(String expExpression, String rankProfile) { diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidatorTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidatorTestCase.java new file mode 100644 index 00000000000..056d2ad534b --- /dev/null +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidatorTestCase.java @@ -0,0 +1,192 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchdefinition.processing; + +import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchdefinition.RankProfileRegistry; +import com.yahoo.searchdefinition.Search; +import com.yahoo.searchdefinition.SearchBuilder; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.tensor.TensorType; +import com.yahoo.yolean.Exceptions; +import org.junit.Test; + +import java.util.Map; +import java.util.stream.Collectors; + +import static com.yahoo.config.model.test.TestUtil.joinLines; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +/** + * @author bratseth + */ +public class RankingExpressionTypeValidatorTestCase { + + @Test + public void tensorFirstPhaseMustProduceDouble() throws Exception { + try { + SearchBuilder builder = new SearchBuilder(); + builder.importString(joinLines( + "search test {", + " document test { ", + " field a type tensor(x[],y[]) {", + " indexing: attribute", + " }", + " }", + " rank-profile my_rank_profile {", + " first-phase {", + " expression: attribute(a)", + " }", + " }", + "}" + )); + builder.build(); + fail("Expected exception"); + } + catch (IllegalArgumentException expected) { + assertEquals("In search definition 'test', rank profile 'my_rank_profile': The first-phase expression must produce a double (a tensor with no dimensions), but produces tensor(x[],y[])", + Exceptions.toMessageString(expected)); + } + } + + @Test + public void tensorSecondPhaseMustProduceDouble() throws Exception { + try { + SearchBuilder builder = new SearchBuilder(); + builder.importString(joinLines( + "search test {", + " document test { ", + " field a type tensor(x[],y[]) {", + " indexing: attribute", + " }", + " }", + " rank-profile my_rank_profile {", + " first-phase {", + " expression: sum(attribute(a))", + " }", + " second-phase {", + " expression: attribute(a)", + " }", + " }", + "}" + )); + builder.build(); + fail("Expected exception"); + } + catch (IllegalArgumentException expected) { + assertEquals("In search definition 'test', rank profile 'my_rank_profile': The second-phase expression must produce a double (a tensor with no dimensions), but produces tensor(x[],y[])", + Exceptions.toMessageString(expected)); + } + } + + @Test + public void tensorConditionsMustHaveTypeCompatibleBranches() throws Exception { + try { + SearchBuilder searchBuilder = new SearchBuilder(); + searchBuilder.importString(joinLines( + "search test {", + " document test { ", + " field a type tensor(x[],y[]) {", + " indexing: attribute", + " }", + " field b type tensor(z[10]) {", + " indexing: attribute", + " }", + " }", + " rank-profile my_rank_profile {", + " first-phase {", + " expression: sum(if(1>0, attribute(a), attribute(b)))", + " }", + " }", + "}" + )); + searchBuilder.build(); + fail("Expected exception"); + } + catch (IllegalArgumentException expected) { + assertEquals("In search definition 'test', rank profile 'my_rank_profile': The first-phase expression is invalid: An if expression must produce compatible types in both alternatives, but the 'true' type is tensor(x[],y[]) while the 'false' type is tensor(z[10])", + Exceptions.toMessageString(expected)); + } + } + + @Test + public void testMacroInvocationTypes() throws Exception { + RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); + SearchBuilder builder = new SearchBuilder(rankProfileRegistry); + builder.importString(joinLines( + "search test {", + " document test { ", + " field a type tensor(x[],y[]) {", + " indexing: attribute", + " }", + " field b type tensor(z[10]) {", + " indexing: attribute", + " }", + " }", + " rank-profile my_rank_profile {", + " macro macro1(attribute_to_use) {", + " expression: attribute(attribute_to_use)", + " }", + " summary-features {", + " macro1(a)", + " macro1(b)", + " }", + " }", + "}" + )); + builder.build(); + RankProfile profile = + builder.getRankProfileRegistry().getRankProfile(builder.getSearch(), "my_rank_profile"); + assertEquals(TensorType.fromSpec("tensor(x[],y[])"), + summaryFeatures(profile).get("macro1(a)").type(profile.typeContext(builder.getQueryProfileRegistry()))); + assertEquals(TensorType.fromSpec("tensor(z[10])"), + summaryFeatures(profile).get("macro1(b)").type(profile.typeContext(builder.getQueryProfileRegistry()))); + } + + @Test + public void testTensorMacroInvocationTypes_Nested() throws Exception { + SearchBuilder builder = new SearchBuilder(); + builder.importString(joinLines( + "search test {", + " document test { ", + " field a type tensor(x[],y[]) {", + " indexing: attribute", + " }", + " field b type tensor(z[10]) {", + " indexing: attribute", + " }", + " }", + " rank-profile my_rank_profile {", + " macro return_a() {", + " expression: return_first(attribute(a), attribute(b))", + " }", + " macro return_b() {", + " expression: return_second(attribute(a), attribute(b))", + " }", + " macro return_first(e1, e2) {", + " expression: e1", + " }", + " macro return_second(e1, e2) {", + " expression: return_first(e2, e1)", + " }", + " summary-features {", + " return_a", + " return_b", + " }", + " }", + "}" + )); + builder.build(); + RankProfile profile = + builder.getRankProfileRegistry().getRankProfile(builder.getSearch(), "my_rank_profile"); + assertEquals(TensorType.fromSpec("tensor(x[],y[])"), + summaryFeatures(profile).get("return_a").type(profile.typeContext(builder.getQueryProfileRegistry()))); + assertEquals(TensorType.fromSpec("tensor(z[10])"), + summaryFeatures(profile).get("return_b").type(profile.typeContext(builder.getQueryProfileRegistry()))); + } + + private Map<String, ReferenceNode> summaryFeatures(RankProfile profile) { + return profile.getSummaryFeatures().stream().collect(Collectors.toMap(f -> f.toString(), f -> f)); + } + +} diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 4693ac5cf4d..96795d2b08f 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -73,7 +73,7 @@ public class RankingExpressionWithTensorFlowTestCase { public void testTensorFlowReferenceWithQueryFeature() { String queryProfile = "<query-profile id='default' type='root'/>"; String queryProfileType = "<query-profile-type id='root'>" + - " <field name='mytensor' type='tensor(d0[3],d1[784])'/>" + + " <field name='query(mytensor)' type='tensor(d0[3],d1[784])'/>" + "</query-profile-type>"; StoringApplicationPackage application = new StoringApplicationPackage(applicationDir, queryProfile, @@ -107,7 +107,7 @@ public class RankingExpressionWithTensorFlowTestCase { public void testTensorFlowReferenceWithFeatureCombination() { String queryProfile = "<query-profile id='default' type='root'/>"; String queryProfileType = "<query-profile-type id='root'>" + - " <field name='mytensor' type='tensor(d0[3],d1[784],d2[10])'/>" + + " <field name='query(mytensor)' type='tensor(d0[3],d1[784],d2[10])'/>" + "</query-profile-type>"; StoringApplicationPackage application = new StoringApplicationPackage(applicationDir, queryProfile, diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java index b001db69768..054c9220225 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java @@ -17,98 +17,129 @@ import com.yahoo.searchdefinition.SearchDefinitionTestCase; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.vespa.model.container.search.QueryProfiles; import org.junit.Test; import java.util.List; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; public class TensorTransformTestCase extends SearchDefinitionTestCase { @Test public void requireThatNormalMaxAndMinAreNotReplaced() throws ParseException { - assertContainsExpression("max(1.0,2.0)", "max(1.0,2.0)"); - assertContainsExpression("min(attribute(double_field),x)", "min(attribute(double_field),x)"); - assertContainsExpression("max(attribute(double_field),attribute(double_array_field))", "max(attribute(double_field),attribute(double_array_field))"); - assertContainsExpression("min(attribute(tensor_field_1),attribute(double_field))", "min(attribute(tensor_field_1),attribute(double_field))"); - assertContainsExpression("max(attribute(tensor_field_1),attribute(tensor_field_2))", "max(attribute(tensor_field_1),attribute(tensor_field_2))"); - assertContainsExpression("min(test_constant_tensor,1.0)", "min(constant(test_constant_tensor),1.0)"); - assertContainsExpression("max(base_constant_tensor,1.0)", "max(constant(base_constant_tensor),1.0)"); - assertContainsExpression("min(constant(file_constant_tensor),1.0)", "min(constant(file_constant_tensor),1.0)"); - assertContainsExpression("max(query(q),1.0)", "max(query(q),1.0)"); - assertContainsExpression("max(query(n),1.0)", "max(query(n),1.0)"); + assertTransformedExpression("max(1.0,2.0)", + "max(1.0,2.0)"); + assertTransformedExpression("min(attribute(double_field),x)", + "min(attribute(double_field),x)"); + assertTransformedExpression("max(attribute(double_field),attribute(double_array_field))", + "max(attribute(double_field),attribute(double_array_field))"); + assertTransformedExpression("min(attribute(tensor_field_1),attribute(double_field))", + "min(attribute(tensor_field_1),attribute(double_field))"); + assertTransformedExpression("reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),sum)", + "reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),sum)"); + assertTransformedExpression("min(constant(test_constant_tensor),1.0)", + "min(test_constant_tensor,1.0)"); + assertTransformedExpression("max(constant(base_constant_tensor),1.0)", + "max(base_constant_tensor,1.0)"); + assertTransformedExpression("min(constant(file_constant_tensor),1.0)", + "min(constant(file_constant_tensor),1.0)"); + assertTransformedExpression("max(query(q),1.0)", + "max(query(q),1.0)"); + assertTransformedExpression("max(query(n),1.0)", + "max(query(n),1.0)"); } @Test public void requireThatMaxAndMinWithTensorAttributesAreReplaced() throws ParseException { - assertContainsExpression("max(attribute(tensor_field_1),x)", "reduce(attribute(tensor_field_1),max,x)"); - assertContainsExpression("1 + max(attribute(tensor_field_1),x)", "1+reduce(attribute(tensor_field_1),max,x)"); - assertContainsExpression("if(attribute(double_field),1 + max(attribute(tensor_field_1),x),0)", "if(attribute(double_field),1+reduce(attribute(tensor_field_1),max,x),0)"); - assertContainsExpression("max(max(attribute(tensor_field_1),attribute(tensor_field_2)),x)", "reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),max,x)"); - assertContainsExpression("max(if(attribute(double_field),attribute(tensor_field_1),attribute(tensor_field_2)),x)", "reduce(if(attribute(double_field),attribute(tensor_field_1),attribute(tensor_field_2)),max,x)"); - assertContainsExpression("max(max(attribute(tensor_field_1),x),x)", "max(reduce(attribute(tensor_field_1),max,x),x)"); // will result in deploy error. - assertContainsExpression("max(max(attribute(tensor_field_2),x),y)", "reduce(reduce(attribute(tensor_field_2),max,x),max,y)"); + assertTransformedExpression("reduce(attribute(tensor_field_1),max,x)", + "max(attribute(tensor_field_1),x)"); + assertTransformedExpression("1+reduce(attribute(tensor_field_1),max,x)", + "1 + max(attribute(tensor_field_1),x)"); + assertTransformedExpression("if(attribute(double_field),1+reduce(attribute(tensor_field_1),max,x),0)", + "if(attribute(double_field),1 + max(attribute(tensor_field_1),x),0)"); + assertTransformedExpression("reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),max,x)", + "max(max(attribute(tensor_field_1),attribute(tensor_field_2)),x)"); + assertTransformedExpression("reduce(if(attribute(double_field),attribute(tensor_field_2),attribute(tensor_field_2)),max,x)", + "max(if(attribute(double_field),attribute(tensor_field_2),attribute(tensor_field_2)),x)"); + assertTransformedExpression("max(reduce(attribute(tensor_field_1),max,x),x)", + "max(max(attribute(tensor_field_1),x),x)"); // will result in deploy error. + assertTransformedExpression("reduce(reduce(attribute(tensor_field_2),max,x),max,y)", + "max(max(attribute(tensor_field_2),x),y)"); } @Test public void requireThatMaxAndMinWithConstantTensorsAreReplaced() throws ParseException { - assertContainsExpression("max(test_constant_tensor,x)", "reduce(constant(test_constant_tensor),max,x)"); - assertContainsExpression("max(base_constant_tensor,x)", "reduce(constant(base_constant_tensor),max,x)"); - assertContainsExpression("min(constant(file_constant_tensor),x)", "reduce(constant(file_constant_tensor),min,x)"); + assertTransformedExpression("reduce(constant(test_constant_tensor),max,x)", + "max(test_constant_tensor,x)"); + assertTransformedExpression("reduce(constant(base_constant_tensor),max,x)", + "max(base_constant_tensor,x)"); + assertTransformedExpression("reduce(constant(file_constant_tensor),min,x)", + "min(constant(file_constant_tensor),x)"); } @Test public void requireThatMaxAndMinWithTensorExpressionsAreReplaced() throws ParseException { - assertContainsExpression("min(attribute(double_field) + attribute(tensor_field_1),x)", "reduce(attribute(double_field)+attribute(tensor_field_1),min,x)"); - assertContainsExpression("min(attribute(tensor_field_1) * attribute(tensor_field_2),x)", "reduce(attribute(tensor_field_1)*attribute(tensor_field_2),min,x)"); - assertContainsExpression("min(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),x)", "reduce(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),min,x)"); - assertContainsExpression("min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)", "min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)"); // because tensor fields are not in attribute(...) - assertContainsExpression("min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)", "reduce(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),min,x)"); + assertTransformedExpression("reduce(attribute(double_field)+attribute(tensor_field_1),min,x)", + "min(attribute(double_field) + attribute(tensor_field_1),x)"); + assertTransformedExpression("reduce(attribute(tensor_field_1)*attribute(tensor_field_2),min,x)", + "min(attribute(tensor_field_1) * attribute(tensor_field_2),x)"); + assertTransformedExpression("reduce(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),min,x)", + "min(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),x)"); + assertTransformedExpression("min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)", + "min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)"); // because tensor fields are not in attribute(...) + assertTransformedExpression("reduce(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),min,x)", + "min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)"); } @Test public void requireThatMaxAndMinWithTensorFromIsReplaced() throws ParseException { - assertContainsExpression("max(tensorFromLabels(attribute(double_array_field)),double_array_field)", "reduce(tensorFromLabels(attribute(double_array_field)),max,double_array_field)"); - assertContainsExpression("max(tensorFromLabels(attribute(double_array_field),x),x)", "reduce(tensorFromLabels(attribute(double_array_field),x),max,x)"); - assertContainsExpression("max(tensorFromWeightedSet(attribute(weightedset_field)),weightedset_field)", "reduce(tensorFromWeightedSet(attribute(weightedset_field)),max,weightedset_field)"); - assertContainsExpression("max(tensorFromWeightedSet(attribute(weightedset_field),x),x)", "reduce(tensorFromWeightedSet(attribute(weightedset_field),x),max,x)"); + assertTransformedExpression("reduce(tensorFromLabels(attribute(double_array_field)),max,double_array_field)", + "max(tensorFromLabels(attribute(double_array_field)),double_array_field)"); + assertTransformedExpression("reduce(tensorFromLabels(attribute(double_array_field),x),max,x)", + "max(tensorFromLabels(attribute(double_array_field),x),x)"); + assertTransformedExpression("reduce(tensorFromWeightedSet(attribute(weightedset_field)),max,weightedset_field)", + "max(tensorFromWeightedSet(attribute(weightedset_field)),weightedset_field)"); + assertTransformedExpression("reduce(tensorFromWeightedSet(attribute(weightedset_field),x),max,x)", + "max(tensorFromWeightedSet(attribute(weightedset_field),x),x)"); } @Test public void requireThatMaxAndMinWithTensorInQueryIsReplaced() throws ParseException { - assertContainsExpression("max(query(q),x)", "reduce(query(q),max,x)"); - assertContainsExpression("max(query(n),x)", "max(query(n),x)"); + assertTransformedExpression("reduce(query(q),max,x)", "max(query(q),x)"); + assertTransformedExpression("max(query(n),x)", "max(query(n),x)"); } @Test public void requireThatMaxAndMinWithTensoresReturnedFromMacrosAreReplaced() throws ParseException { - assertContainsExpression("max(returns_tensor,x)", "reduce(rankingExpression(returns_tensor),max,x)"); - assertContainsExpression("max(wraps_returns_tensor,x)", "reduce(rankingExpression(wraps_returns_tensor),max,x)"); - assertContainsExpression("max(tensor_inheriting,x)", "reduce(rankingExpression(tensor_inheriting),max,x)"); - assertContainsExpression("max(returns_tensor_with_arg(attribute(tensor_field_1)),x)", "reduce(rankingExpression(returns_tensor_with_arg@),max,x)"); + assertTransformedExpression("reduce(rankingExpression(returns_tensor),max,x)", + "max(returns_tensor,x)"); + assertTransformedExpression("reduce(rankingExpression(wraps_returns_tensor),max,x)", + "max(wraps_returns_tensor,x)"); + assertTransformedExpression("reduce(rankingExpression(tensor_inheriting),max,x)", + "max(tensor_inheriting,x)"); + assertTransformedExpression("reduce(rankingExpression(returns_tensor_with_arg@),max,x)", + "max(returns_tensor_with_arg(attribute(tensor_field_1)),x)"); } - private void assertContainsExpression(String expr, String transformedExpression) throws ParseException { - assertTrue("Expected expression '" + transformedExpression + "' found", - containsExpression(expr, transformedExpression)); - } - - private boolean containsExpression(String expr, String transformedExpression) throws ParseException { - for (Pair<String, String> rankPropertyExpression : buildSearch(expr)) { + private void assertTransformedExpression(String expected, String original) throws ParseException { + for (Pair<String, String> rankPropertyExpression : buildSearch(original)) { String rankProperty = rankPropertyExpression.getFirst(); if (rankProperty.equals("rankingExpression(firstphase).rankingScript")) { String rankExpression = censorBindingHash(rankPropertyExpression.getSecond().replace(" ","")); - return rankExpression.equals(transformedExpression); + assertEquals(expected, rankExpression); + return; } } - return false; + fail("No 'rankingExpression(firstphase).rankingScript' property produced"); } private List<Pair<String, String>> buildSearch(String expression) throws ParseException { RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); - SearchBuilder builder = new SearchBuilder(rankProfileRegistry); + QueryProfileRegistry queryProfiles = setupQueryProfileTypes(); + SearchBuilder builder = new SearchBuilder(rankProfileRegistry, queryProfiles); builder.importString( "search test {\n" + " document test { \n" + @@ -167,16 +198,16 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase { " }\n" + " }\n" + "}\n"); - builder.build(new BaseDeployLogger(), setupQueryProfileTypes()); + builder.build(new BaseDeployLogger()); Search s = builder.getSearch(); - RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry()); + RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(queryProfiles); List<Pair<String, String>> testRankProperties = new RawRankProfile(test, - new QueryProfileRegistry(), + queryProfiles, new AttributeFields(s)).configProperties(); return testRankProperties; } - private static QueryProfiles setupQueryProfileTypes() { + private static QueryProfileRegistry setupQueryProfileTypes() { QueryProfileRegistry registry = new QueryProfileRegistry(); QueryProfileTypeRegistry typeRegistry = registry.getTypeRegistry(); QueryProfileType type = new QueryProfileType(new ComponentId("testtype")); @@ -185,7 +216,7 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase { type.addField(new FieldDescription("ranking.features.query(n)", FieldType.fromString("integer", typeRegistry)), typeRegistry); typeRegistry.register(type); - return new QueryProfiles(registry); + return registry; } private String censorBindingHash(String s) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java index 2e2858da238..262aba89f27 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java @@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression; import com.google.common.collect.ImmutableList; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.FunctionReferenceContext; import com.yahoo.searchlib.rankingexpression.rule.SerializationContext; import com.yahoo.text.Utf8; @@ -11,9 +12,9 @@ import java.security.NoSuchAlgorithmException; import java.util.*; /** - * <p>A function defined by a ranking expression</p> + * A function defined by a ranking expression * - * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a> + * @author Simon Thoresen * @author bratseth */ public class ExpressionFunction { @@ -23,7 +24,7 @@ public class ExpressionFunction { private final RankingExpression body; /** - * <p>Constructs a new function</p> + * Constructs a new function * * @param name the name of this function * @param arguments its argument names @@ -43,28 +44,27 @@ public class ExpressionFunction { public RankingExpression getBody() { return body; } /** - * <p>Create and return an instance of this function based on the given - * arguments. If function calls are nested, this call might produce - * additional scripts.</p> + * Creates and returns an instance of this function based on the given + * arguments. If function calls are nested, this call may produce + * additional functions. * * @param context the context used to expand this - * @param arguments the arguments to instantiate on. + * @param argumentValues the arguments to instantiate on. * @param path the expansion path leading to this. * @return the script function instance created. */ - public Instance expand(SerializationContext context, List<ExpressionNode> arguments, Deque<String> path) { + public Instance expand(SerializationContext context, List<ExpressionNode> argumentValues, Deque<String> path) { Map<String, String> argumentBindings = new HashMap<>(); - for (int i = 0; i < this.arguments.size() && i < arguments.size(); ++i) { - argumentBindings.put(this.arguments.get(i), arguments.get(i).toString(context, path, null)); + for (int i = 0; i < arguments.size() && i < arguments.size(); ++i) { + argumentBindings.put(arguments.get(i), argumentValues.get(i).toString(context, path, null)); } - return new Instance(toSymbol(argumentBindings), body.getRoot().toString(context.createBinding(argumentBindings), path, null)); + return new Instance(toSymbol(argumentBindings), body.getRoot().toString(context.withBindings(argumentBindings), path, null)); } /** * Returns a symbolic string that represents this function with a given * list of arguments. The arguments are mangled by hashing the string - * representation of the argument expressions, so we might need to revisit - * this if we start seeing collisions. + * representation of the argument expressions. * * @param argumentBindings the bound arguments to include in the symbolic name. * @return the symbolic name for an instance of this function @@ -85,8 +85,8 @@ public class ExpressionFunction { /** - * <p>Returns a more unique hash code than what Java's own {@link - * String#hashCode()} method would produce.</p> + * Returns a more unique hash code than what Java's own {@link + * String#hashCode()} method would produce. * * @param str The string to hash. * @return A 64 bit long hash code. @@ -136,4 +136,5 @@ public class ExpressionFunction { } } + } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/FeatureList.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/FeatureList.java index 49466f1974d..f0532d9d433 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/FeatureList.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/FeatureList.java @@ -91,8 +91,8 @@ public class FeatureList implements Iterable<ReferenceNode> { /** * Returns the feature at the given index. * - * @param i The index of the feature to return. - * @return The featuer at the given index. + * @param i the index of the feature to return. + * @return the feature at the given index. */ public ReferenceNode get(int i) { return features.get(i); @@ -137,4 +137,5 @@ public class FeatureList implements Iterable<ReferenceNode> { public Iterator<ReferenceNode> iterator() { return features.iterator(); } + } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java index c8d90e8c4e8..6b2422d7cb2 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java @@ -244,10 +244,6 @@ public class RankingExpression implements Serializable { * @return a list of named rank properties required to implement this expression. */ public Map<String, String> getRankProperties(List<ExpressionFunction> macros) { - Map<String, ExpressionFunction> arg = new HashMap<>(); - for (ExpressionFunction function : macros) { - arg.put(function.getName(), function); - } Deque<String> path = new LinkedList<>(); SerializationContext context = new SerializationContext(macros); String serializedRoot = root.toString(context, path, null); @@ -272,7 +268,7 @@ public class RankingExpression implements Serializable { * * @throws IllegalArgumentException if this expression is not type correct in this context */ - public TensorType type(TypeContext context) { + public TensorType type(TypeContext<Reference> context) { return root.type(context); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/Reference.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/Reference.java new file mode 100644 index 00000000000..6277721e8f5 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/Reference.java @@ -0,0 +1,121 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression; + +import com.yahoo.searchlib.rankingexpression.rule.Arguments; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.SerializationContext; +import com.yahoo.tensor.evaluation.TypeContext; + +import java.util.Deque; +import java.util.Objects; +import java.util.Optional; +import java.util.stream.Collectors; + +/** + * A reference to a feature, function, or value in ranking expressions + * + * @author bratseth + */ +public class Reference extends TypeContext.Name { + + private final String name; + private final Arguments arguments; + + /** + * The output, or null if none + */ + private final String output; + + public Reference(String name, Arguments arguments, String output) { + super(name); + Objects.requireNonNull(name, "name cannot be null"); + Objects.requireNonNull(arguments, "arguments cannot be null"); + this.name = name; + this.arguments = arguments; + this.output = output; + } + + public String name() { return name; } + + public Arguments arguments() { return arguments; } + + public String output() { return output; } + + /** + * Creates a reference to a simple feature consisting of a name and a single argument + */ + public static Reference simple(String name, String argumentValue) { + return new Reference(name, + new Arguments(new ReferenceNode(argumentValue)), + null); + } + + /** + * Returns the given simple feature as a reference, or empty if it is not a valid simple + * feature string on the form name(argument). + */ + public static Optional<Reference> simple(String feature) { + int startParenthesis = feature.indexOf('('); + if (startParenthesis < 0) + return Optional.empty(); + int endParenthesis = feature.lastIndexOf(')'); + String featureName = feature.substring(0, startParenthesis); + if (startParenthesis < 1 || endParenthesis < startParenthesis) return Optional.empty(); + String argument = feature.substring(startParenthesis + 1, endParenthesis); + if (argument.startsWith("'") || argument.startsWith("\"")) + argument = argument.substring(1); + if (argument.endsWith("'") || argument.endsWith("\"")) + argument = argument.substring(0, argument.length() - 1); + return Optional.of(simple(featureName, argument)); + } + + /** + * Returns whether this is a simple identifier - no arguments or output + */ + public boolean isIdentifier() { + return this.arguments.expressions().size() == 0 && output == null; + } + + public Reference withArguments(Arguments arguments) { + return new Reference(name, arguments, output); + } + + public Reference withOutput(String output) { + return new Reference(name, arguments, output); + } + + @Override + public boolean equals(Object o) { + if (o == this) return true; + if (!(o instanceof Reference)) return false; + Reference other = (Reference) o; + if (!Objects.equals(other.name, this.name)) return false; + if (!Objects.equals(other.arguments, this.arguments)) return false; + if (!Objects.equals(other.output, this.output)) return false; + return true; + } + + @Override + public int hashCode() { + return Objects.hash(name, arguments, output); + } + + @Override + public String toString() { + return toString(new SerializationContext(), null, null); + } + + public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) { + StringBuilder b = new StringBuilder(name); + if (arguments != null && arguments.expressions().size() > 0) + b.append("(").append(arguments.expressions().stream() + .map(node -> node.toString(context, path, parent)) + .collect(Collectors.joining(","))).append(")"); + if (output != null) + b.append(".").append(output); + return b.toString(); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java index 5f8daa69ecf..ee5952d9aea 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java @@ -2,6 +2,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.TensorType; import java.util.Arrays; @@ -82,8 +83,8 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable { } @Override - public TensorType getType(String name) { - Integer index = nameToIndex().get(name); + public TensorType getType(Reference reference) { + Integer index = nameToIndex().get(reference.toString()); if (index == null) return null; return values[index].type(); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java index 861f9565d66..4e046df11ca 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java @@ -1,9 +1,11 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.evaluation; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; import java.util.Set; @@ -14,7 +16,7 @@ import java.util.stream.Collectors; * * @author bratseth */ -public abstract class Context implements EvaluationContext { +public abstract class Context implements EvaluationContext<Reference> { /** * Returns the value of a simple variable name. @@ -24,6 +26,11 @@ public abstract class Context implements EvaluationContext { */ public abstract Value get(String name); + @Override + public TensorType getType(String reference) { + throw new UnsupportedOperationException("Not able to parse gereral references from string form"); + } + /** Returns a variable as a tensor */ @Override public Tensor getTensor(String name) { return get(name).asTensor(); } @@ -46,6 +53,7 @@ public abstract class Context implements EvaluationContext { * calculation to output several), or null to output the * "main" (or only) value. */ + // TODO: Remove/change to use reference? public Value get(String name, Arguments arguments, String output) { if (arguments != null && arguments.expressions().size() > 0) name = name + "(" + arguments.expressions().stream().map(ExpressionNode::toString).collect(Collectors.joining(",")) + ")"; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java index 0625e8506cc..0004036da4b 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java @@ -2,6 +2,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.TensorType; /** @@ -68,7 +69,9 @@ public class DoubleOnlyArrayContext extends AbstractArrayContext { } @Override - public TensorType getType(String name) { return TensorType.empty; } + public TensorType getType(Reference reference) { + return TensorType.empty; // Double only + } /** Perform a slow lookup by name */ @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java index a81d0c89f8f..4ef24d60bba 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.evaluation; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.TensorType; import java.util.Collections; @@ -15,7 +16,7 @@ import java.util.Set; */ public class MapContext extends Context { - private Map<String, Value> bindings = new HashMap<>(); + private Map<String, Value> bindings = new HashMap<>(); // TODO: Change String to Reference private boolean frozen = false; @@ -42,8 +43,8 @@ public class MapContext extends Context { /** Returns the type of the given value key, or null if it is not bound. */ @Override - public TensorType getType(String key) { - Value value = bindings.get(key); + public TensorType getType(Reference key) { + Value value = bindings.get(key.toString()); if (value == null) return null; return value.type(); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapTypeContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapTypeContext.java new file mode 100644 index 00000000000..2a42e2d92f7 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapTypeContext.java @@ -0,0 +1,38 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.evaluation;// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** + * A context which only contains type information. + * + * @author bratseth + */ +public class MapTypeContext implements TypeContext<Reference> { + + private final Map<Reference, TensorType> featureTypes = new HashMap<>(); + + public void setType(Reference reference, TensorType type) { + featureTypes.put(reference, type); + } + + @Override + public TensorType getType(String reference) { + throw new UnsupportedOperationException("Not able to parse gereral references from string form"); + } + + @Override + public TensorType getType(Reference reference) { + return featureTypes.get(reference); + } + + /** Returns an unmodifiable map of the bindings in this */ + public Map<Reference, TensorType> bindings() { return Collections.unmodifiableMap(featureTypes); } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeMapContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeMapContext.java deleted file mode 100644 index ff2088263d8..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeMapContext.java +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.evaluation;// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.evaluation.TypeContext; - -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; - -/** - * A context which only contains type information. - * - * @author bratseth - */ -public class TypeMapContext implements TypeContext { - - private final Map<String, TensorType> featureTypes = new HashMap<>(); - - public void setType(String name, TensorType type) { - featureTypes.put(name, type); - } - - @Override - public TensorType getType(String name) { - return featureTypes.get(name); - } - - /** Returns an unmodifiable map of the bindings in this */ - public Map<String, TensorType> bindings() { return Collections.unmodifiableMap(featureTypes); } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java index 8ee4cdbf297..649c70122f1 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; @@ -26,7 +27,7 @@ public class GBDTForestNode extends ExpressionNode { } @Override - public final TensorType type(TypeContext context) { return TensorType.empty; } + public final TensorType type(TypeContext<Reference> context) { return TensorType.empty; } @Override public final Value evaluate(Context context) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java index aac635b2545..53a286f09f6 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; @@ -51,7 +52,7 @@ public final class GBDTNode extends ExpressionNode { public final double[] values() { return values; } @Override - public final TensorType type(TypeContext context) { return TensorType.empty; } + public final TensorType type(TypeContext<Reference> context) { return TensorType.empty; } @Override public final Value evaluate(Context context) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Arguments.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Arguments.java index fb9a7cb9ad7..d3a12d0f312 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Arguments.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Arguments.java @@ -13,7 +13,7 @@ import java.util.List; /** * A set of argument expressions to a function or feature. - * This is immutable. + * This is a value object. * * @author bratseth */ @@ -22,7 +22,11 @@ public final class Arguments implements Serializable { private final ImmutableList<ExpressionNode> expressions; public Arguments() { - this(null); + this(ImmutableList.of()); + } + + public Arguments(ExpressionNode singleArgument) { + this(ImmutableList.of(singleArgument)); } public Arguments(List<? extends ExpressionNode> expressions) { @@ -38,9 +42,12 @@ public final class Arguments implements Serializable { this.expressions = b.build(); } - /** Returns an unmodifiable list of the expressions in this */ + /** Returns an unmodifiable list of the expressions in this, never null */ public List<ExpressionNode> expressions() { return expressions; } + /** Returns the number of arguments in this */ + public int size() { return expressions.size(); } + /** Evaluate all arguments in this */ public Value[] evaluate(Context context) { Value[] values=new Value[expressions.size()]; @@ -62,8 +69,9 @@ public final class Arguments implements Serializable { } @Override - public boolean equals(Object rhs) { - return rhs instanceof Arguments && expressions.equals(((Arguments)rhs).expressions); + public boolean equals(Object other) { + if (other == this) return true; + return other instanceof Arguments && expressions.equals(((Arguments)other).expressions); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java index fc6428a4c33..49c49bed9bd 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java @@ -2,6 +2,7 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.google.common.collect.ImmutableList; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; @@ -80,7 +81,7 @@ public final class ArithmeticNode extends CompositeNode { } @Override - public TensorType type(TypeContext context) { + public TensorType type(TypeContext<Reference> context) { // Compute type using tensor types as arithmetic operators are supported on tensors // and is correct also in the special case of doubles. // As all our functions are type-commutative, we don't need to take operator precedence into account diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/BooleanNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/BooleanNode.java index 1d7d9b1ecda..cd4ddbcae55 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/BooleanNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/BooleanNode.java @@ -5,7 +5,6 @@ package com.yahoo.searchlib.rankingexpression.rule; * A node which produces a boolean value when evaluated. * * @author bratseth - * @since 5.1.21 */ public abstract class BooleanNode extends CompositeNode { } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java index 7601c0e6180..eb328486045 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.rule; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; @@ -49,7 +50,7 @@ public class ComparisonNode extends BooleanNode { } @Override - public TensorType type(TypeContext context) { + public TensorType type(TypeContext<Reference> context) { return TensorType.empty; // by definition } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java index 1ea8d03f0eb..3ddd7223349 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.rule; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; @@ -49,7 +50,7 @@ public final class ConstantNode extends ExpressionNode { } @Override - public TensorType type(TypeContext context) { return value.type(); } + public TensorType type(TypeContext<Reference> context) { return value.type(); } @Override public Value evaluate(Context context) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java index fd9fab99db8..47c2897e4a4 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.rule; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; @@ -50,7 +51,7 @@ public final class EmbracedNode extends CompositeNode { } @Override - public TensorType type(TypeContext context) { + public TensorType type(TypeContext<Reference> context) { return value.type(context); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java index 477f4db4981..6bb163590de 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.rule; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; @@ -48,7 +49,7 @@ public abstract class ExpressionNode implements Serializable { * @param context the variable type bindings to use for this evaluation * @throws IllegalArgumentException if there are variables which are not bound in the given map */ - public abstract TensorType type(TypeContext context); + public abstract TensorType type(TypeContext<Reference> context); /** * Returns the value of evaluating this expression over the given context. diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java index 79515229019..1da2210a39c 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.rule; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; @@ -67,7 +68,7 @@ public final class FunctionNode extends CompositeNode { } @Override - public TensorType type(TypeContext context) { + public TensorType type(TypeContext<Reference> context) { if (arguments.expressions().size() == 0) return TensorType.empty; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java new file mode 100644 index 00000000000..ed1e2838717 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java @@ -0,0 +1,74 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.rule; + +import com.google.common.collect.ImmutableMap; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; + +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** + * The context of a function invocation. + * + * @author bratseth + */ +public class FunctionReferenceContext { + + /** Expression functions indexed by name */ + private final ImmutableMap<String, ExpressionFunction> functions; + + /** Mapping from argument names to the expressions they resolve to */ + // TODO: Make private + public final Map<String, String> bindings = new HashMap<>(); + + /** Create a context for a single serialization task */ + public FunctionReferenceContext() { + this(Collections.emptyList()); + } + + /** Create a context for a single serialization task */ + public FunctionReferenceContext(Collection<ExpressionFunction> functions) { + this(toMap(functions), Collections.emptyMap()); + } + + public FunctionReferenceContext(Collection<ExpressionFunction> functions, Map<String, String> bindings) { + this(toMap(functions), bindings); + } + + /** Create a context for a single serialization task */ + public FunctionReferenceContext(Map<String, ExpressionFunction> functions) { + this(functions.values()); + } + + /** Create a context for a single serialization task */ + public FunctionReferenceContext(Map<String, ExpressionFunction> functions, Map<String, String> bindings) { + this.functions = ImmutableMap.copyOf(functions); + if (bindings != null) + this.bindings.putAll(bindings); + } + + private static ImmutableMap<String, ExpressionFunction> toMap(Collection<ExpressionFunction> list) { + ImmutableMap.Builder<String,ExpressionFunction> mapBuilder = new ImmutableMap.Builder<>(); + for (ExpressionFunction function : list) + mapBuilder.put(function.getName(), function); + return mapBuilder.build(); + } + + /** + * Returns a function or null if it isn't defined in this context + */ + public ExpressionFunction getFunction(String name) { return functions.get(name); } + + protected Map<String, ExpressionFunction> functions() { return functions; } + + /** Returns the resolution of an argument, or null if it isn't defined in this context */ + public String getBinding(String name) { return bindings.get(name); } + + /** Returns a new context with the bindings replaced by the given bindings */ + public FunctionReferenceContext withBindings(Map<String, String> bindings) { + return new FunctionReferenceContext(this.functions, bindings); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java index e42884ecc05..c87eb0ace39 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.rule; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.Value; @@ -48,7 +49,7 @@ public class GeneratorLambdaFunctionNode extends CompositeNode { } @Override - public TensorType type(TypeContext context) { return type; } + public TensorType type(TypeContext<Reference> context) { return type; } /** Evaluate this in a context which must have the arguments bound */ @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java index 66b250736e8..ee4edac4941 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.rule; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; @@ -75,7 +76,7 @@ public final class IfNode extends CompositeNode { } @Override - public TensorType type(TypeContext context) { + public TensorType type(TypeContext<Reference> context) { TensorType trueType = trueExpression.type(context); TensorType falseType = falseExpression.type(context); return trueType.dimensionwiseGeneralizationWith(falseType).orElseThrow(() -> diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java index da946228291..61086f8182a 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java @@ -2,6 +2,7 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.google.common.collect.ImmutableList; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.Value; @@ -57,7 +58,7 @@ public class LambdaFunctionNode extends CompositeNode { } @Override - public TensorType type(TypeContext context) { + public TensorType type(TypeContext<Reference> context) { return TensorType.empty; // by definition - no nested lambdas } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java index f55ed59b65c..f1adf331630 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.rule; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; @@ -14,6 +15,7 @@ import java.util.Deque; * * @author Simon Thoresen */ +// TODO: This is achieved by ReferenceNode in almost all cases - remove this public final class NameNode extends ExpressionNode { private final String name; @@ -32,7 +34,7 @@ public final class NameNode extends ExpressionNode { } @Override - public TensorType type(TypeContext context) { throw new RuntimeException("Named nodes can not have a type"); } + public TensorType type(TypeContext<Reference> context) { throw new RuntimeException("Named nodes can not have a type"); } @Override public Value evaluate(Context context) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java index 9cbe5f98c72..fcc03dc4862 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.rule; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; @@ -38,7 +39,7 @@ public class NegativeNode extends CompositeNode { } @Override - public TensorType type(TypeContext context) { + public TensorType type(TypeContext<Reference> context) { return value.type(context); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java index e7041600635..a539f496ff5 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.rule; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; @@ -38,7 +39,7 @@ public class NotNode extends BooleanNode { } @Override - public TensorType type(TypeContext context) { + public TensorType type(TypeContext<Reference> context) { return value.type(context); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java index 05a6773c5cb..78f53b1593d 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java @@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; @@ -13,114 +14,102 @@ import java.util.Deque; import java.util.List; /** - * A node referring either to a value in the context or to another named ranking expression. + * A node referring either to a value in the context or to a named ranking expression (function aka macro). * * @author simon * @author bratseth */ public final class ReferenceNode extends CompositeNode { - private final String name, output; - - private final Arguments arguments; + private final Reference reference; + /* Creates a node with a simple identifier reference */ public ReferenceNode(String name) { this(name, null, null); } public ReferenceNode(String name, List<? extends ExpressionNode> arguments, String output) { - this.name = name; - this.arguments = arguments != null ? new Arguments(arguments) : new Arguments(); - this.output = output; + this.reference = new Reference(name, + arguments != null ? new Arguments(arguments) : new Arguments(), + output); + } + + public ReferenceNode(Reference reference) { + this.reference = reference; } public String getName() { - return name; + return reference.name(); } /** Returns the arguments, never null */ - public Arguments getArguments() { return arguments; } + public Arguments getArguments() { return reference.arguments(); } /** Returns a copy of this where the arguments are replaced by the given arguments */ public ReferenceNode setArguments(List<ExpressionNode> arguments) { - return new ReferenceNode(name, arguments, output); + return new ReferenceNode(reference.withArguments(new Arguments(arguments))); } /** Returns the specific output this references, or null if none specified */ - public String getOutput() { return output; } + public String getOutput() { return reference.output(); } /** Returns a copy of this node with a modified output */ public ReferenceNode setOutput(String output) { - return new ReferenceNode(name, arguments.expressions(), output); + return new ReferenceNode(reference.withOutput(output)); } /** Returns an empty list as this has no children */ @Override - public List<ExpressionNode> children() { return arguments.expressions(); } + public List<ExpressionNode> children() { return reference.arguments().expressions(); } @Override public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) { - if (path == null) - path = new ArrayDeque<>(); - String myName = this.name; - String myOutput = this.output; - List<ExpressionNode> myArguments = this.arguments.expressions(); - - String resolvedArgument = context.getBinding(myName); - if (resolvedArgument != null && this.arguments.expressions().size() == 0 && myOutput == null) { - // Replace this whole node with the value of the argument value that it maps to - myName = resolvedArgument; - myArguments = null; - myOutput = null; - } else if (context.getFunction(myName) != null) { - // Replace by the referenced expression - ExpressionFunction function = context.getFunction(myName); - if (function != null && myArguments != null && function.arguments().size() == myArguments.size() && myOutput == null) { - String myPath = name + this.arguments.expressions(); - if (path.contains(myPath)) { - throw new IllegalStateException("Cycle in ranking expression function: " + path); - } - path.addLast(myPath); - ExpressionFunction.Instance instance = function.expand(context, myArguments, path); - path.removeLast(); - context.addFunctionSerialization(RankingExpression.propertyName(instance.getName()), instance.getExpressionString()); - myName = "rankingExpression(" + instance.getName() + ")"; - myArguments = null; - myOutput = null; - } + if (reference.isIdentifier() && context.getBinding(getName()) != null) { + // a bound identifier: replace by the value it is bound to + return context.getBinding(getName()); } - // Always print the same way, the magic is already done. - StringBuilder ret = new StringBuilder(myName); - if (myArguments != null && myArguments.size() > 0) { - ret.append("("); - for (int i = 0; i < myArguments.size(); ++i) { - ret.append(myArguments.get(i).toString(context, path, this)); - if (i < myArguments.size() - 1) { - ret.append(","); - } - } - ret.append(")"); + + ExpressionFunction function = context.getFunction(getName()); + if (function != null && function.arguments().size() == getArguments().size() && getOutput() == null) { + // a function reference: replace by the referenced function wrapped in rankingExpression + if (path == null) + path = new ArrayDeque<>(); + String myPath = getName() + getArguments().expressions(); + if (path.contains(myPath)) + throw new IllegalStateException("Cycle in ranking expression function: " + path); + path.addLast(myPath); + ExpressionFunction.Instance instance = function.expand(context, getArguments().expressions(), path); + path.removeLast(); + context.addFunctionSerialization(RankingExpression.propertyName(instance.getName()), instance.getExpressionString()); + return "rankingExpression(" + instance.getName() + ")"; } - ret.append(myOutput != null ? "." + myOutput : ""); - return ret.toString(); + + // not resolved in this context: output as-is + return reference.toString(context, path, parent); } + /** Returns the reference of this node */ + public Reference reference() { return reference; } + @Override - public TensorType type(TypeContext context) { - // Don't support outputs of different type, for simplicity - return context.getType(toString()); + public TensorType type(TypeContext<Reference> context) { + TensorType type = context.getType(reference); + if (type == null) + throw new IllegalArgumentException("Unknown feature '" + toString() + "'"); + return type; } @Override public Value evaluate(Context context) { - if (arguments.expressions().isEmpty() && output == null) - return context.get(name); - return context.get(name, arguments, output); + // TODO: Context should accept a Reference instead. + if (reference.isIdentifier()) + return context.get(reference.name()); + return context.get(getName(), getArguments(), getOutput()); } @Override public CompositeNode setChildren(List<ExpressionNode> newChildren) { - return new ReferenceNode(name, newChildren, output); + return setArguments(newChildren); } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java index ba765d07094..796c13a8669 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java @@ -16,17 +16,11 @@ import java.util.Map; * * @author bratseth */ -public class SerializationContext { +public class SerializationContext extends FunctionReferenceContext { - /** Expression functions indexed by name */ - private final ImmutableMap<String, ExpressionFunction> functions; - - /** A cache of already serialized expressions indexed by name */ + /** Serialized form of functions indexed by name */ private final Map<String, String> serializedFunctions; - /** Mapping from argument names to the expressions they resolve to */ - public final Map<String, String> bindings = new HashMap<>(); - /** Create a context for a single serialization task */ public SerializationContext() { this(Collections.emptyList()); @@ -77,17 +71,10 @@ public class SerializationContext { */ public SerializationContext(ImmutableMap<String,ExpressionFunction> functions, Map<String, String> bindings, Map<String, String> serializedFunctions) { - this.functions = functions; + super(functions, bindings); this.serializedFunctions = serializedFunctions; - if (bindings != null) - this.bindings.putAll(bindings); } - /** - * Returns a function or null if it isn't defined in this context - */ - public ExpressionFunction getFunction(String name) { return functions.get(name); } - /** Adds the serialization of a function */ public void addFunctionSerialization(String name, String expressionString) { serializedFunctions.put(name, expressionString); @@ -98,17 +85,9 @@ public class SerializationContext { return serializedFunctions.get(name); } - /** - * Returns the resolution of an argument, or null if it isn't defined in this context - */ - public String getBinding(String name) { return bindings.get(name); } - - /** - * Returns a new context which shares the functions and serialized function map with this but has different - * arguments. - */ - public SerializationContext createBinding(Map<String, String> arguments) { - return new SerializationContext(this.functions, arguments, this.serializedFunctions); + @Override + public SerializationContext withBindings(Map<String, String> bindings) { + return new SerializationContext(functions().values(), bindings, this.serializedFunctions); } public Map<String, String> serializedFunctions() { return serializedFunctions; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java index a7b82f4753f..cb31219579a 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java @@ -2,6 +2,7 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.google.common.collect.ImmutableList; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; @@ -60,7 +61,7 @@ public class SetMembershipNode extends BooleanNode { } @Override - public TensorType type(TypeContext context) { + public TensorType type(TypeContext<Reference> context) { return TensorType.empty; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index ec6af4bb413..6c9b6bb4a98 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -2,6 +2,7 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.google.common.annotations.Beta; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; @@ -64,7 +65,7 @@ public class TensorFunctionNode extends CompositeNode { } @Override - public TensorType type(TypeContext context) { return function.type(context); } + public TensorType type(TypeContext<Reference> context) { return function.type(context); } @Override public Value evaluate(Context context) { @@ -111,12 +112,13 @@ public class TensorFunctionNode extends CompositeNode { public PrimitiveTensorFunction toPrimitive() { return this; } @Override - public TensorType type(TypeContext context) { - return expression.type(context); + @SuppressWarnings("unchecked") // Generics awkwardness + public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { + return expression.type((TypeContext<Reference>)context); } @Override - public Tensor evaluate(EvaluationContext context) { + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { return expression.evaluate((Context)context).asTensor(); } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index e9030cf5852..f2122bb5da9 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -378,8 +378,13 @@ public class EvaluationTestCase { private static class StructuredTestContext extends MapContext { @Override + public Value get(String feature) { + throw new RuntimeException("Called simple get for feature " + feature); + } + + @Override public Value get(String name, Arguments arguments, String output) { - if (!name.equals("average")) { + if ( ! name.equals("average")) { throw new IllegalArgumentException("Unknown operation '" + name + "'"); } if (arguments.expressions().size() != 2) { diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeResolutionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeResolutionTestCase.java index c882c887c8d..a08d510eec4 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeResolutionTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeResolutionTestCase.java @@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.TypeContext; @@ -18,12 +19,17 @@ public class TypeResolutionTestCase { @Test public void testTypeResolution() { - TypeMapContext context = new TypeMapContext(); - context.setType("query(x1)", TensorType.fromSpec("tensor(x[])")); - context.setType("query(x2)", TensorType.fromSpec("tensor(x[10])")); - context.setType("query(y1)", TensorType.fromSpec("tensor(y[])")); - context.setType("query(xy1)", TensorType.fromSpec("tensor(x[10],y[])")); - context.setType("query(xy2)", TensorType.fromSpec("tensor(x[],y[10])")); + MapTypeContext context = new MapTypeContext(); + context.setType(Reference.simple("query", "x1"), + TensorType.fromSpec("tensor(x[])")); + context.setType(Reference.simple("query", "x2"), + TensorType.fromSpec("tensor(x[10])")); + context.setType(Reference.simple("query", "y1"), + TensorType.fromSpec("tensor(y[])")); + context.setType(Reference.simple("query", "xy1"), + TensorType.fromSpec("tensor(x[10],y[])")); + context.setType(Reference.simple("query", "xy2"), + TensorType.fromSpec("tensor(x[],y[10])")); assertType("tensor(x[])", "query(x1)", context); assertType("tensor(x[])", "if (1>0, query(x1), query(x2))", context); @@ -31,7 +37,7 @@ public class TypeResolutionTestCase { assertIncompatibleType("if (1>0, query(x1), query(y1))", context); } - private void assertType(String type, String expression, TypeContext context) { + private void assertType(String type, String expression, TypeContext<Reference> context) { try { assertEquals(TensorType.fromSpec(type), new RankingExpression(expression).type(context)); } @@ -40,7 +46,7 @@ public class TypeResolutionTestCase { } } - private void assertIncompatibleType(String expression, TypeContext context) { + private void assertIncompatibleType(String expression, TypeContext<Reference> context) { try { new RankingExpression(expression).type(context); fail("Expected type incompatibility exception"); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/rule/ArgumentsTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/rule/ArgumentsTestCase.java index 867331e99ce..303135888d8 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/rule/ArgumentsTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/rule/ArgumentsTestCase.java @@ -9,13 +9,13 @@ import java.util.Collections; import static org.junit.Assert.*; /** - * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a> + * @author Simon Thoresen */ public class ArgumentsTestCase { @Test public void requireThatAccessorsWork() { - Arguments args = new Arguments(null); + Arguments args = new Arguments(); assertTrue(args.expressions().isEmpty()); args = new Arguments(Collections.<ExpressionNode>emptyList()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java index 3fb94f1251b..8a969180113 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java @@ -10,7 +10,7 @@ import com.yahoo.tensor.Tensor; * @author bratseth */ @Beta -public interface EvaluationContext extends TypeContext { +public interface EvaluationContext<NAMETYPE extends TypeContext.Name> extends TypeContext<NAMETYPE> { /** Returns the tensor bound to this name, or null if none */ Tensor getTensor(String name); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java index 9fe6b7d053f..b9394da31e3 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java @@ -11,17 +11,20 @@ import java.util.HashMap; * @author bratseth */ @Beta -public class MapEvaluationContext implements EvaluationContext { +public class MapEvaluationContext implements EvaluationContext<TypeContext.Name> { private final java.util.Map<String, Tensor> bindings = new HashMap<>(); - static MapEvaluationContext empty() { return new MapEvaluationContext(); } - public void put(String name, Tensor tensor) { bindings.put(name, tensor); } @Override public TensorType getType(String name) { - Tensor tensor = bindings.get(name); + return getType(new Name(name)); + } + + @Override + public TensorType getType(Name name) { + Tensor tensor = bindings.get(name.toString()); if (tensor == null) return null; return tensor.type(); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java index 760a225efdf..ff2e6318b37 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java @@ -8,7 +8,7 @@ import com.yahoo.tensor.TensorType; * * @author bratseth */ -public interface TypeContext { +public interface TypeContext<NAMETYPE extends TypeContext.Name> { /** * Returns the type of the tensor with this name. @@ -16,6 +16,39 @@ public interface TypeContext { * @return returns the type of the tensor which will be returned by calling getTensor(name) * or null if getTensor will return null. */ + TensorType getType(NAMETYPE name); + + /** + * Returns the type of the tensor with this name by converting from a string name. + * + * @return returns the type of the tensor which will be returned by calling getTensor(name) + * or null if getTensor will return null. + */ TensorType getType(String name); + /** A name which is just a string. Names are value objects. */ + class Name { + + private final String name; + + public Name(String name) { + this.name = name; + } + + @Override + public String toString() { return name; } + + @Override + public int hashCode() { return name.hashCode(); } + + @Override + public boolean equals(Object other) { + if (other == this) return true; + if ( ! (other instanceof Name)) return false; + return ((Name)other).name.equals(this.name); + } + + } + + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java index 34beb465d4c..acb2363cba4 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java @@ -44,7 +44,7 @@ public class VariableTensor extends PrimitiveTensorFunction { public PrimitiveTensorFunction toPrimitive() { return this; } @Override - public TensorType type(TypeContext context) { + public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { TensorType givenType = context.getType(name); if (givenType == null) return null; verifyType(givenType); @@ -52,7 +52,7 @@ public class VariableTensor extends PrimitiveTensorFunction { } @Override - public Tensor evaluate(EvaluationContext context) { + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor tensor = context.getTensor(name); if (tensor == null) return null; verifyType(tensor.type()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java index 2109b730e1a..bfc0938abcc 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java @@ -18,10 +18,14 @@ public abstract class CompositeTensorFunction extends TensorFunction { /** Finds the type this produces by first converting it to a primitive function */ @Override - public final TensorType type(TypeContext context) { return toPrimitive().type(context); } + public final <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { + return toPrimitive().type(context); + } /** Evaluates this by first converting it to a primitive function */ @Override - public final Tensor evaluate(EvaluationContext context) { return toPrimitive().evaluate(context); } + public final <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + return toPrimitive().evaluate(context); + } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index c77ed1c0526..a073053bec8 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -60,7 +60,7 @@ public class Concat extends PrimitiveTensorFunction { } @Override - public TensorType type(TypeContext context) { + public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return type(argumentA.type(context), argumentB.type(context)); } @@ -74,7 +74,7 @@ public class Concat extends PrimitiveTensorFunction { } @Override - public Tensor evaluate(EvaluationContext context) { + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); a = ensureIndexedDimension(dimension, a); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java index 50b479da168..a43de297b9a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java @@ -42,10 +42,10 @@ public class ConstantTensor extends PrimitiveTensorFunction { public PrimitiveTensorFunction toPrimitive() { return this; } @Override - public TensorType type(TypeContext context) { return constant.type(); } + public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return constant.type(); } @Override - public Tensor evaluate(EvaluationContext context) { return constant; } + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { return constant; } @Override public String toString(ToStringContext context) { return constant.toString(); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java index e70d1de3db7..edfa8253eb9 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java @@ -61,10 +61,10 @@ public class Generate extends PrimitiveTensorFunction { public PrimitiveTensorFunction toPrimitive() { return this; } @Override - public TensorType type(TypeContext context) { return type; } + public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return type; } @Override - public Tensor evaluate(EvaluationContext context) { + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor.Builder builder = Tensor.Builder.of(type); IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(dimensionSizes(type)); for (int i = 0; i < indexes.size(); i++) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 7812c985091..17e1c103ea3 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -95,12 +95,12 @@ public class Join extends PrimitiveTensorFunction { } @Override - public TensorType type(TypeContext context) { + public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return new TensorType.Builder(argumentA.type(context), argumentB.type(context)).build(); } @Override - public Tensor evaluate(EvaluationContext context) { + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build(); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java index 53504868ff2..4a338e5501e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java @@ -53,12 +53,12 @@ public class Map extends PrimitiveTensorFunction { } @Override - public TensorType type(TypeContext context) { + public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return argument.type(context); } @Override - public Tensor evaluate(EvaluationContext context) { + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor argument = argument().evaluate(context); Tensor.Builder builder = Tensor.Builder.of(argument.type()); for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index 76a938b9fe2..e045effbe7e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -101,11 +101,12 @@ public class Reduce extends PrimitiveTensorFunction { } @Override - public TensorType type(TypeContext context) { + public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return type(argument.type(context)); } private TensorType type(TensorType argumentType) { + if (dimensions.isEmpty()) return TensorType.empty; // means reduce all TensorType.Builder builder = new TensorType.Builder(); for (TensorType.Dimension dimension : argumentType.dimensions()) if ( ! dimensions.contains(dimension.name())) // keep @@ -114,7 +115,7 @@ public class Reduce extends PrimitiveTensorFunction { } @Override - public Tensor evaluate(EvaluationContext context) { + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor argument = this.argument.evaluate(context); if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions)) throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " + diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java index de3d2be265a..af4492ca1e4 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -72,7 +72,7 @@ public class Rename extends PrimitiveTensorFunction { public PrimitiveTensorFunction toPrimitive() { return this; } @Override - public TensorType type(TypeContext context) { + public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return type(argument.type(context)); } @@ -84,7 +84,7 @@ public class Rename extends PrimitiveTensorFunction { } @Override - public Tensor evaluate(EvaluationContext context) { + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor tensor = argument.evaluate(context); TensorType renamedType = type(tensor.type()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java index 78ab09c7820..e805e9d87bb 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java @@ -43,14 +43,14 @@ public abstract class TensorFunction { * * @param context a context which must be passed to all nexted functions when evaluating */ - public abstract Tensor evaluate(EvaluationContext context); + public abstract <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context); /** * Returns the type of the tensor this produces given the input types in the context * * @param context a context which must be passed to all nexted functions when evaluating */ - public abstract TensorType type(TypeContext context); + public abstract <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context); /** Evaluate with no context */ public final Tensor evaluate() { return evaluate(new MapEvaluationContext()); } @@ -58,7 +58,7 @@ public abstract class TensorFunction { /** * Return a string representation of this context. * - * @param context a context which must be passed to all nexted functions when requesting the string value + * @param context a context which must be passed to all nested functions when requesting the string value */ public abstract String toString(ToStringContext context); |