summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <jonbratseth@yahoo.com>2018-02-20 16:58:07 +0100
committerGitHub <noreply@github.com>2018-02-20 16:58:07 +0100
commit7cbcd92168f36a63f0dade4acc5683e134e9ac48 (patch)
treea15aaa05bb4d0592d655cd5c8c57fe02f258a39e
parent3cb51aa803fef2ab0d622768ff623a80691d6811 (diff)
parentbf9358e1c983ca3b2c4f9630873ed4e53634236f (diff)
Merge pull request #5065 from vespa-engine/bratseth/typecheck-all-2
Bratseth/typecheck all 2
-rw-r--r--config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java2
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/FeatureNames.java110
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java156
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java55
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java26
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/TypeMapContext.java32
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java4
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java4
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/processing/IndexingValues.java2
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java3
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidator.java82
-rw-r--r--config-model/src/test/derived/rankexpression/rank-profiles.cfg4
-rw-r--r--config-model/src/test/derived/rankexpression/rankexpression.sd38
-rw-r--r--config-model/src/test/derived/rankexpression/summary.cfg18
-rw-r--r--config-model/src/test/derived/rankexpression/summarymap.cfg26
-rw-r--r--config-model/src/test/derived/tensor/rank-profiles.cfg2
-rw-r--r--config-model/src/test/derived/tensor/tensor.sd2
-rw-r--r--config-model/src/test/examples/rankpropvars.sd8
-rw-r--r--config-model/src/test/examples/simple.sd2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/FeatureNamesTestCase.java20
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java8
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java3
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java40
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java4
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidatorTestCase.java192
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java4
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java133
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java31
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/FeatureList.java5
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/Reference.java121
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java5
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java10
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java5
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java7
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapTypeContext.java38
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeMapContext.java32
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java3
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java3
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Arguments.java18
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java3
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/BooleanNode.java1
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java3
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java3
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java3
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java3
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java3
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java74
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java3
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java3
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java3
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java3
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java3
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java111
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java33
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java3
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java10
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java7
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeResolutionTestCase.java22
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/rule/ArgumentsTestCase.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java11
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java35
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java5
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java6
74 files changed, 1191 insertions, 470 deletions
diff --git a/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java b/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java
index d6b916680d8..bd94f67e4a7 100644
--- a/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java
+++ b/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java
@@ -323,7 +323,7 @@ public class DeployState implements ConfigDefinitionStore {
closeIgnoreException(reader.getReader());
}
}
- builder.build(logger, queryProfiles);
+ builder.build(logger);
return SearchDocumentModel.fromBuilderAndNames(builder, names);
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/FeatureNames.java b/config-model/src/main/java/com/yahoo/searchdefinition/FeatureNames.java
index dd03cb8b2a7..dc59d9cb3e5 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/FeatureNames.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/FeatureNames.java
@@ -5,11 +5,10 @@
*/
package com.yahoo.searchdefinition;
-import java.util.Arrays;
-import java.util.List;
+import com.yahoo.searchlib.rankingexpression.Reference;
+
import java.util.Optional;
import java.util.regex.Pattern;
-import java.util.stream.Collectors;
/**
* Utility methods for query, document and constant rank feature names
@@ -20,85 +19,16 @@ public class FeatureNames {
private static final Pattern identifierRegexp = Pattern.compile("[A-Za-z0-9_][A-Za-z0-9_-]*");
- /**
- * <p>Returns the given query, document or constant feature in canonical form.
- * A feature name consists of a feature type name (query, attribute or constant),
- * followed by one argument enclosed in quotes.
- * The argument may be an identifier or any string single or double quoted.</p>
- *
- * <p>Argument string values may not contain comma, single quote nor double quote characters.</p>
- *
- * <p><i>The canonical form use no quotes for arguments which are identifiers, and double quotes otherwise.</i></p>
- *
- * <p>Note that the above definition is not true for features in general, which accept any ranking expression
- * as argument.</p>
- *
- * @throws IllegalArgumentException if the feature name is not valid
- */
- // Note that this implementation is more general than what is described above:
- // It accepts any number of arguments and an optional output
- public static String canonicalize(String feature) {
- return canonicalizeIfValid(feature).orElseThrow(() ->
- new IllegalArgumentException("A feature name must be on the form query(name), attribute(name) or " +
- "constant(name), but was '" + feature + "'"
- ));
- }
-
- /**
- * Canonicalizes the given argument as in canonicalize, but returns empty instead of throwing an exception if
- * the argument is not a valid feature
- */
- public static Optional<String> canonicalizeIfValid(String feature) {
- int startParenthesis = feature.indexOf('(');
- if (startParenthesis < 0)
- return Optional.empty();
- int endParenthesis = feature.lastIndexOf(')');
- String featureType = feature.substring(0, startParenthesis);
- if ( ! ( featureType.equals("query") || featureType.equals("attribute") || featureType.equals("constant")))
- return Optional.empty();
- if (startParenthesis < 1) return Optional.of(feature); // No arguments
- if (endParenthesis < startParenthesis)
- return Optional.empty();
- String argumentString = feature.substring(startParenthesis + 1, endParenthesis);
- List<String> canonicalizedArguments =
- Arrays.stream(argumentString.split(","))
- .map(FeatureNames::canonicalizeArgument)
- .collect(Collectors.toList());
- return Optional.of(featureType + "(" +
- canonicalizedArguments.stream().collect(Collectors.joining(",")) +
- feature.substring(endParenthesis));
- }
-
- /** Canomicalizes a single argument */
- private static String canonicalizeArgument(String argument) {
- if (argument.startsWith("'")) {
- if ( ! argument.endsWith("'"))
- throw new IllegalArgumentException("Feature arguments starting by a single quote " +
- "must end by a single quote, but was \"" + argument + "\"");
- argument = argument.substring(1, argument.length() - 1);
- }
- if (argument.startsWith("\"")) {
- if ( ! argument.endsWith("\""))
- throw new IllegalArgumentException("Feature arguments starting by a double quote " +
- "must end by a double quote, but was '" + argument + "'");
- argument = argument.substring(1, argument.length() - 1);
- }
- if (identifierRegexp.matcher(argument).matches())
- return argument;
- else
- return "\"" + argument + "\"";
- }
-
- public static String asConstantFeature(String constantName) {
- return canonicalize("constant(\"" + constantName + "\")");
+ public static Reference asConstantFeature(String constantName) {
+ return Reference.simple("constant", quoteIfNecessary(constantName));
}
- public static String asAttributeFeature(String attributeName) {
- return canonicalize("attribute(\"" + attributeName + "\")");
+ public static Reference asAttributeFeature(String attributeName) {
+ return Reference.simple("attribute", quoteIfNecessary(attributeName));
}
- public static String asQueryFeature(String propertyName) {
- return canonicalize("query(\"" + propertyName + "\")");
+ public static Reference asQueryFeature(String propertyName) {
+ return Reference.simple("query", quoteIfNecessary(propertyName));
}
/**
@@ -106,15 +36,21 @@ public class FeatureNames {
* or empty if it is not a valid query, attribute or constant feature name
*/
public static Optional<String> argumentOf(String feature) {
- return canonicalizeIfValid(feature).map(f -> {
- int startParenthesis = f.indexOf("(");
- int endParenthesis = f.indexOf(")");
- String possiblyQuotedArgument = f.substring(startParenthesis + 1, endParenthesis);
- if (possiblyQuotedArgument.startsWith("\""))
- return possiblyQuotedArgument.substring(1, possiblyQuotedArgument.length() - 1);
- else
- return possiblyQuotedArgument;
- });
+ Optional<Reference> reference = Reference.simple(feature);
+ if ( ! reference.isPresent()) return Optional.empty();
+ if ( ! ( reference.get().name().equals("attribute") ||
+ reference.get().name().equals("constant") ||
+ reference.get().name().equals("query")))
+ return Optional.empty();
+
+ return Optional.of(reference.get().arguments().expressions().get(0).toString());
+ }
+
+ private static String quoteIfNecessary(String s) {
+ if (identifierRegexp.matcher(s).matches())
+ return s;
+ else
+ return "\"" + s + "\"";
}
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
new file mode 100644
index 00000000000..fcae756eab3
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
@@ -0,0 +1,156 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchdefinition;
+
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.rule.Arguments;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.FunctionReferenceContext;
+import com.yahoo.searchlib.rankingexpression.rule.NameNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+/**
+ * A context which only contains type information.
+ * This returns empty tensor types (double) for unknown features which are not
+ * query, attribute or constant features, as we do not have information about which such
+ * features exist (but we know those that exist are doubles).
+ *
+ * @author bratseth
+ */
+public class MapEvaluationTypeContext extends FunctionReferenceContext implements TypeContext<Reference> {
+
+ private final Map<Reference, TensorType> featureTypes = new HashMap<>();
+
+ public MapEvaluationTypeContext(Collection<ExpressionFunction> functions) {
+ super(functions);
+ }
+
+ public MapEvaluationTypeContext(Map<String, ExpressionFunction> functions,
+ Map<String, String> bindings,
+ Map<Reference, TensorType> featureTypes) {
+ super(functions, bindings);
+ this.featureTypes.putAll(featureTypes);
+ }
+
+ public void setType(Reference reference, TensorType type) {
+ featureTypes.put(reference, type);
+ }
+
+ @Override
+ public TensorType getType(String reference) {
+ throw new UnsupportedOperationException("Not able to parse gereral references from string form");
+ }
+
+ @Override
+ public TensorType getType(Reference reference) {
+ Optional<String> binding = boundIdentifier(reference);
+ if (binding.isPresent()) {
+ try {
+ // This is not pretty, but changing to bind expressions rather
+ // than their string values requires deeper changes
+ return new RankingExpression(binding.get()).type(this);
+ }
+ catch (ParseException e) {
+ throw new IllegalArgumentException(e);
+ }
+ }
+
+ if (isSimpleFeature(reference)) {
+ // The argument may be a local identifier bound to the actual value
+ String argument = simpleArgument(reference.arguments()).get();
+ reference = Reference.simple(reference.name(), bindings.getOrDefault(argument, argument));
+ return featureTypes.get(reference);
+ }
+
+ Optional<ExpressionFunction> function = functionInvocation(reference);
+ if (function.isPresent()) {
+ return function.get().getBody().type(this.withBindings(bind(function.get().arguments(), reference.arguments())));
+ }
+
+ // We do not know what this is - since we do not have complete knowledge abut the match features
+ // in Java we must assume this is a match feature and return the double type - which is the type of all
+ // all match features
+ return TensorType.empty;
+ }
+
+ /**
+ * Returns the binding if this reference is a simple identifier which is bound in this context.
+ * Returns empty otherwise.
+ */
+ private Optional<String> boundIdentifier(Reference reference) {
+ if ( ! reference.arguments().isEmpty()) return Optional.empty();
+ if ( reference.output() != null) return Optional.empty();
+ return Optional.ofNullable(bindings.get(reference.name()));
+ }
+
+ /**
+ * Return whether the reference (discarding the output) is a simple feature
+ * ("attribute(name)", "constant(name)" or "query(name)").
+ * We disregard the output because all outputs under a simple feature have the same type.
+ */
+ private boolean isSimpleFeature(Reference reference) {
+ Optional<String> argument = simpleArgument(reference.arguments());
+ if ( ! argument.isPresent()) return false;
+ return reference.name().equals("attribute") ||
+ reference.name().equals("constant") ||
+ reference.name().equals("query");
+ }
+
+ /**
+ * If these arguments contains one simple argument string, it is returned.
+ * Otherwise null is returned.
+ */
+ private Optional<String> simpleArgument(Arguments arguments) {
+ if (arguments.expressions().size() != 1) return Optional.empty();
+ ExpressionNode argument = arguments.expressions().get(0);
+
+ if ( ! (argument instanceof ReferenceNode)) return Optional.empty();
+ ReferenceNode refArgument = (ReferenceNode)argument;
+
+ if ( ! refArgument.reference().isIdentifier()) return Optional.empty();
+
+ return Optional.of(refArgument.getName());
+ }
+
+ private Optional<ExpressionFunction> functionInvocation(Reference reference) {
+ if (reference.output() != null) return Optional.empty();
+ ExpressionFunction function = functions().get(reference.name());
+ if (function == null) return Optional.empty();
+ if (function.arguments().size() != reference.arguments().size()) return Optional.empty();
+ return Optional.of(function);
+ }
+
+ /** Binds the given list of formal arguments to their actual values */
+ private Map<String, String> bind(List<String> formalArguments,
+ Arguments invocationArguments) {
+ Map<String, String> bindings = new HashMap<>(formalArguments.size());
+ for (int i = 0; i < formalArguments.size(); i++) {
+ String identifier = invocationArguments.expressions().get(i).toString();
+ identifier = super.bindings.getOrDefault(identifier, identifier);
+ bindings.put(formalArguments.get(i), identifier);
+ }
+ return bindings;
+ }
+
+ public Map<Reference, TensorType> featureTypes() {
+ return Collections.unmodifiableMap(featureTypes);
+ }
+
+ @Override
+ public MapEvaluationTypeContext withBindings(Map<String, String> bindings) {
+ if (bindings.isEmpty() && this.bindings.isEmpty()) return this;
+ return new MapEvaluationTypeContext(functions(), bindings, featureTypes);
+ }
+
+}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
index bcbc7cc99e2..bd645422d50 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
@@ -2,15 +2,9 @@
package com.yahoo.searchdefinition;
import com.yahoo.config.application.api.ApplicationPackage;
-import com.yahoo.config.model.deploy.DeployState;
-import com.yahoo.io.reader.NamedReader;
-import com.yahoo.processing.request.CompoundName;
-import com.yahoo.search.query.profile.QueryProfile;
import com.yahoo.search.query.profile.QueryProfileRegistry;
-import com.yahoo.search.query.profile.config.QueryProfileXMLReader;
import com.yahoo.search.query.profile.types.FieldDescription;
import com.yahoo.search.query.profile.types.QueryProfileType;
-import com.yahoo.search.query.profile.types.TensorFieldType;
import com.yahoo.search.query.ranking.Diversity;
import com.yahoo.searchdefinition.document.SDField;
import com.yahoo.searchdefinition.expressiontransforms.RankProfileTransformContext;
@@ -18,8 +12,8 @@ import com.yahoo.searchdefinition.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.FeatureList;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.TypeMapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.tensor.TensorType;
@@ -39,7 +33,9 @@ import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
+import java.util.Optional;
import java.util.Set;
+import java.util.stream.Collectors;
/**
* Represents a rank profile - a named set of ranking settings
@@ -363,14 +359,14 @@ public class RankProfile implements Serializable, Cloneable {
/** Returns a read-only view of the summary features to use in this profile. This is never null */
public Set<ReferenceNode> getSummaryFeatures() {
- if (summaryFeatures!=null) return Collections.unmodifiableSet(summaryFeatures);
- if (getInherited()!=null) return getInherited().getSummaryFeatures();
+ if (summaryFeatures != null) return Collections.unmodifiableSet(summaryFeatures);
+ if (getInherited() != null) return getInherited().getSummaryFeatures();
return Collections.emptySet();
}
public void addSummaryFeature(ReferenceNode feature) {
- if (summaryFeatures==null)
- summaryFeatures=new LinkedHashSet<>();
+ if (summaryFeatures == null)
+ summaryFeatures = new LinkedHashSet<>();
summaryFeatures.add(feature);
}
@@ -585,8 +581,11 @@ public class RankProfile implements Serializable, Cloneable {
}
/**
- * Will take the parser-set textual ranking expressions and turn into objects
+ * Will take the parser-set textual ranking expressions and turn into ranking expression objects,
+ * if not already done
*/
+ // TODO: There doesn't appear to be any good reason to defer parsing of ranking expressions
+ // until this is called. Simplify by parsing them right away.
public void parseExpressions() {
try {
parseRankingExpressions();
@@ -604,20 +603,23 @@ public class RankProfile implements Serializable, Cloneable {
for (Map.Entry<String, Macro> e : getMacros().entrySet()) {
String macroName = e.getKey();
Macro macro = e.getValue();
- RankingExpression expr = parseRankingExpression(macroName, macro.getTextualExpression());
- macro.setRankingExpression(expr);
- macro.setTextualExpression(expr.getRoot().toString());
+ if (macro.getRankingExpression() == null) {
+ RankingExpression expr = parseRankingExpression(macroName, macro.getTextualExpression());
+ macro.setRankingExpression(expr);
+ macro.setTextualExpression(expr.getRoot().toString());
+ }
}
}
/**
* Passes ranking expressions on to parser
+ *
* @throws ParseException if either of the ranking expressions could not be parsed
*/
private void parseRankingExpressions() throws ParseException {
- if (getFirstPhaseRankingString() != null)
+ if (getFirstPhaseRankingString() != null && firstPhaseRanking == null)
setFirstPhaseRanking(parseRankingExpression("firstphase", getFirstPhaseRankingString()));
- if (getSecondPhaseRankingString() != null)
+ if (getSecondPhaseRankingString() != null && secondPhaseRanking == null)
setSecondPhaseRanking(parseRankingExpression("secondphase", getSecondPhaseRankingString()));
}
@@ -748,7 +750,9 @@ public class RankProfile implements Serializable, Cloneable {
* referable from this rank profile.
*/
public TypeContext typeContext(QueryProfileRegistry queryProfiles) {
- TypeMapContext context = new TypeMapContext();
+ MapEvaluationTypeContext context = new MapEvaluationTypeContext(getMacros().values().stream()
+ .map(Macro::asExpressionFunction)
+ .collect(Collectors.toList()));
// Add small constants
getConstants().forEach((k, v) -> context.setType(FeatureNames.asConstantFeature(k), v.type()));
@@ -764,15 +768,18 @@ public class RankProfile implements Serializable, Cloneable {
for (QueryProfileType queryProfileType : queryProfiles.getTypeRegistry().allComponents()) {
for (FieldDescription field : queryProfileType.declaredFields().values()) {
TensorType type = field.getType().asTensorType();
- String feature = FeatureNames.asQueryFeature(field.getName());
- TensorType existingType = context.getType(feature);
+ Optional<Reference> feature = Reference.simple(field.getName());
+ if ( ! feature.isPresent() || ! feature.get().name().equals("query")) continue;
+
+ TensorType existingType = context.getType(feature.get());
if (existingType != null)
type = existingType.dimensionwiseGeneralizationWith(type).orElseThrow( () ->
new IllegalArgumentException(queryProfileType + " contains query feature " + feature +
" with type " + field.getType().asTensorType() +
", but this is already defined " +
- "in another query profile with type " + context.getType(feature)));
- context.setType(feature, type);
+ "in another query profile with type " +
+ context.getType(feature.get())));
+ context.setType(feature.get(), type);
}
}
@@ -910,7 +917,7 @@ public class RankProfile implements Serializable, Cloneable {
*/
public static class Macro implements Serializable, Cloneable {
- private String name=null;
+ private final String name;
private String textualExpression=null;
private RankingExpression expression=null;
private List<String> formalParams = new ArrayList<>();
@@ -955,7 +962,7 @@ public class RankProfile implements Serializable, Cloneable {
return inline && formalParams.size() == 0; // only inline no-arg macros;
}
- public ExpressionFunction toExpressionMacro() {
+ public ExpressionFunction asExpressionFunction() {
return new ExpressionFunction(getName(), getFormalParams(), getRankingExpression());
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java b/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java
index 762c0fec838..e7cd21ac834 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java
@@ -18,6 +18,7 @@ import com.yahoo.searchdefinition.parser.TokenMgrError;
import com.yahoo.searchdefinition.processing.Processing;
import com.yahoo.vespa.documentmodel.DocumentModel;
import com.yahoo.vespa.model.container.search.QueryProfiles;
+import com.yahoo.yolean.Exceptions;
import java.io.File;
import java.io.IOException;
@@ -34,7 +35,6 @@ import java.util.List;
* expressions, using the setRankXXX() methods, 3) invoke the {@link #build()} method, and 4) retrieve the built
* search objects using the {@link #getSearch(String)} method.
*/
-// TODO: This should be cleaned up and more or maybe completely taken over by MockApplicationPackage
public class SearchBuilder {
private final DocumentTypeManager docTypeMgr = new DocumentTypeManager();
@@ -154,7 +154,7 @@ public class SearchBuilder {
} catch (TokenMgrError e) {
throw new ParseException("Unknown symbol: " + e.getMessage());
} catch (ParseException pe) {
- throw new ParseException(stream.formatException(pe.getMessage()));
+ throw new ParseException(stream.formatException(Exceptions.toMessageString(pe)));
}
return importRawSearch(search);
}
@@ -196,11 +196,7 @@ public class SearchBuilder {
* @throws IllegalStateException Thrown if this method has already been called.
*/
public void build() {
- build(new BaseDeployLogger(), new QueryProfiles());
- }
-
- public void build(DeployLogger logger) {
- build(logger, new QueryProfiles());
+ build(new BaseDeployLogger());
}
/**
@@ -209,12 +205,10 @@ public class SearchBuilder {
*
* @throws IllegalStateException Thrown if this method has already been called.
* @param deployLogger The logger to use during build
- * @param queryProfiles The query profiles contained in the application this search is part of.
*/
- public void build(DeployLogger deployLogger, QueryProfiles queryProfiles) {
- if (isBuilt) {
- throw new IllegalStateException("Searches already built.");
- }
+ public void build(DeployLogger deployLogger) {
+ if (isBuilt) throw new IllegalStateException("Model already built");
+
List<Search> built = new ArrayList<>();
List<SDDocumentType> sdocs = new ArrayList<>();
sdocs.add(SDDocumentType.VESPA_DOCUMENT);
@@ -240,7 +234,7 @@ public class SearchBuilder {
for (Search search : new SearchOrderer().order(searchList)) {
new FieldOperationApplierForSearch().process(search);
// These two needed for a couple of old unit tests, ideally these are just read from app
- process(search, deployLogger, queryProfiles);
+ process(search, deployLogger, new QueryProfiles(queryProfileRegistry));
built.add(search);
}
builder.addToModel(searchList);
@@ -254,8 +248,6 @@ public class SearchBuilder {
/**
* Processes and returns the given {@link Search} object. This method has been factored out of the {@link
* #build()} method so that subclasses can choose not to build anything.
- *
- * @param search The object to build.
*/
protected void process(Search search, DeployLogger deployLogger, QueryProfiles queryProfiles) {
Processing.process(search, deployLogger, rankProfileRegistry, queryProfiles);
@@ -352,7 +344,7 @@ public class SearchBuilder {
rankProfileRegistry,
queryprofileRegistry);
builder.importFile(fileName);
- builder.build(deployLogger, new QueryProfiles());
+ builder.build(deployLogger);
return builder;
}
@@ -368,7 +360,7 @@ public class SearchBuilder {
for (Iterator<Path> i = Files.list(new File(dir).toPath()).filter(p -> p.getFileName().toString().endsWith(".sd")).iterator(); i.hasNext(); ) {
builder.importFile(i.next());
}
- builder.build(new BaseDeployLogger(), new QueryProfiles());
+ builder.build(new BaseDeployLogger());
return builder;
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/TypeMapContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/TypeMapContext.java
deleted file mode 100644
index 40e9db1413f..00000000000
--- a/config-model/src/main/java/com/yahoo/searchdefinition/TypeMapContext.java
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchdefinition;
-
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.evaluation.TypeContext;
-
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.Map;
-
-/**
- * A context which only contains type information.
- *
- * @author bratseth
- */
-public class TypeMapContext implements TypeContext {
-
- private final Map<String, TensorType> featureTypes = new HashMap<>();
-
- public void setType(String name, TensorType type) {
- featureTypes.put(FeatureNames.canonicalize(name), type);
- }
-
- @Override
- public TensorType getType(String name) {
- return featureTypes.get(FeatureNames.canonicalize(name));
- }
-
- /** Returns an unmodifiable map of the bindings in this */
- public Map<String, TensorType> bindings() { return Collections.unmodifiableMap(featureTypes); }
-
-}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
index ea02f960800..b02362154d9 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
@@ -188,7 +188,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
if (macros.isEmpty()) return;
Map<String, ExpressionFunction> expressionMacros = new LinkedHashMap<>();
for (Map.Entry<String, RankProfile.Macro> macro : macros.entrySet()) {
- expressionMacros.put(macro.getKey(), macro.getValue().toExpressionMacro());
+ expressionMacros.put(macro.getKey(), macro.getValue().asExpressionFunction());
}
Map<String, String> macroProperties = new LinkedHashMap<>();
@@ -223,7 +223,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
// Is the feature a macro?
if (context.getFunction(referenceNode.getName()) != null) {
context.addFunctionSerialization(RankingExpression.propertyName(referenceNode.getName()),
- referenceNode.toString(context, null, null));
+ referenceNode.toString(context, null, null));
ReferenceNode newReferenceNode = new ReferenceNode("rankingExpression(" + referenceNode.getName() + ")", referenceNode.getArguments().expressions(), referenceNode.getOutput());
macroSummaryFeatures.put(referenceNode.getName(), newReferenceNode);
i.remove(); // Will add the expanded one in next block
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index 2b997aa25f2..f16697b5ba6 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -208,6 +208,10 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
throw new IllegalArgumentException("Model refers Placeholder '" + macroName +
"' of type " + requiredType + " but this macro is not present in " +
profile);
+ // TODO: We should verify this in the (function reference(s) this is invoked (starting from first/second
+ // phase and summary features), as it may only resolve correctly given those bindings
+ // Or, probably better, annotate the macros with type constraints here and verify during general
+ // type verification
TensorType actualType = macro.getRankingExpression().getRoot().type(profile.typeContext(queryProfiles));
if ( actualType == null)
throw new IllegalArgumentException("Model refers Placeholder '" + macroName +
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/IndexingValues.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/IndexingValues.java
index ee65c9bec02..cc634abef01 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/IndexingValues.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/IndexingValues.java
@@ -13,7 +13,7 @@ import com.yahoo.vespa.indexinglanguage.expressions.OutputExpression;
import com.yahoo.vespa.model.container.search.QueryProfiles;
/**
- * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen Hult</a>
+ * @author Simon Thoresen Hult
*/
public class IndexingValues extends Processor {
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java
index 90183848094..061a803cb48 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java
@@ -76,8 +76,9 @@ public class Processing {
ImportedFieldsInSummayValidator::new,
FastAccessValidator::new,
ReservedMacroNames::new,
+ RankingExpressionTypeValidator::new,
- // These two should be last.
+ // These should be last.
IndexingValidation::new,
IndexingValues::new);
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidator.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidator.java
new file mode 100644
index 00000000000..fa2610d77a1
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidator.java
@@ -0,0 +1,82 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchdefinition.processing;
+
+import com.yahoo.config.application.api.DeployLogger;
+import com.yahoo.search.query.profile.QueryProfileRegistry;
+import com.yahoo.searchdefinition.RankProfile;
+import com.yahoo.searchdefinition.RankProfileRegistry;
+import com.yahoo.searchdefinition.Search;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
+import com.yahoo.vespa.model.container.search.QueryProfiles;
+
+/**
+ * Validates the types of all ranking expressions under a search instance:
+ * Some operators constrain the types of inputs, and first-and second-phase expressions
+ * must return scalar values. In addition, the existence of all referred attribute, query and constant
+ * features is ensured.
+ *
+ * @author bratseth
+ */
+public class RankingExpressionTypeValidator extends Processor {
+
+ private final QueryProfileRegistry queryProfiles;
+
+ public RankingExpressionTypeValidator(Search search,
+ DeployLogger deployLogger,
+ RankProfileRegistry rankProfileRegistry,
+ QueryProfiles queryProfiles) {
+ super(search, deployLogger, rankProfileRegistry, queryProfiles);
+ this.queryProfiles = queryProfiles.getRegistry();
+ }
+
+ @Override
+ public void process() {
+ for (RankProfile profile : rankProfileRegistry.allRankProfiles()) {
+ try {
+ validate(profile);
+ }
+ catch (IllegalArgumentException e) {
+ throw new IllegalArgumentException("In " + search + ", " + profile, e);
+ }
+ }
+ }
+
+ /** Throws an IllegalArgumentException if the given rank profile does not produce valid type */
+ private void validate(RankProfile profile) {
+ profile.parseExpressions();
+ TypeContext context = profile.typeContext(queryProfiles);
+ profile.getSummaryFeatures().forEach(f -> ensureValid(f, "summary feature " + f, context));
+ ensureValidDouble(profile.getFirstPhaseRanking(), "first-phase expression", context);
+ ensureValidDouble(profile.getSecondPhaseRanking(), "second-phase expression", context);
+ }
+
+ private TensorType ensureValid(RankingExpression expression, String expressionDescription, TypeContext context) {
+ if (expression == null) return null;
+ return ensureValid(expression.getRoot(), expressionDescription, context);
+ }
+
+ private TensorType ensureValid(ExpressionNode expression, String expressionDescription, TypeContext context) {
+ TensorType type;
+ try {
+ type = expression.type(context);
+ }
+ catch (IllegalArgumentException e) {
+ throw new IllegalArgumentException("The " + expressionDescription + " is invalid", e);
+ }
+ if (type == null) // Not expected to happen
+ throw new IllegalStateException("Could not determine the type produced by " + expressionDescription);
+ return type;
+ }
+
+ private void ensureValidDouble(RankingExpression expression, String expressionDescription, TypeContext context) {
+ if (expression == null) return;
+ TensorType type = ensureValid(expression, expressionDescription, context);
+ if ( ! type.equals(TensorType.empty))
+ throw new IllegalArgumentException("The " + expressionDescription + " must produce a double " +
+ "(a tensor with no dimensions), but produces " + type);
+ }
+
+}
diff --git a/config-model/src/test/derived/rankexpression/rank-profiles.cfg b/config-model/src/test/derived/rankexpression/rank-profiles.cfg
index e890b75770b..f5652c31d2a 100644
--- a/config-model/src/test/derived/rankexpression/rank-profiles.cfg
+++ b/config-model/src/test/derived/rankexpression/rank-profiles.cfg
@@ -24,7 +24,7 @@ rankprofile[0].fef.property[10].value "4"
rankprofile[0].fef.property[11].name "vespa.dump.feature"
rankprofile[0].fef.property[11].value "attribute(foo1).out"
rankprofile[0].fef.property[12].name "vespa.dump.feature"
-rankprofile[0].fef.property[12].value "attribute(bar1.out)"
+rankprofile[0].fef.property[12].value "attribute(bar1)"
rankprofile[0].fef.property[13].name "vespa.dump.feature"
rankprofile[0].fef.property[13].value "attribute(foo2).out"
rankprofile[0].fef.property[14].name "vespa.dump.feature"
@@ -64,7 +64,7 @@ rankprofile[2].fef.property[2].value "10 + feature(arg1).out.out"
rankprofile[2].fef.property[3].name "vespa.summary.feature"
rankprofile[2].fef.property[3].value "attribute(foo1).out"
rankprofile[2].fef.property[4].name "vespa.summary.feature"
-rankprofile[2].fef.property[4].value "attribute(bar1.out)"
+rankprofile[2].fef.property[4].value "attribute(bar1)"
rankprofile[2].fef.property[5].name "vespa.summary.feature"
rankprofile[2].fef.property[5].value "attribute(foo2).out"
rankprofile[2].fef.property[6].name "vespa.summary.feature"
diff --git a/config-model/src/test/derived/rankexpression/rankexpression.sd b/config-model/src/test/derived/rankexpression/rankexpression.sd
index 8ed1f2bab4c..d3e0057cfe1 100644
--- a/config-model/src/test/derived/rankexpression/rankexpression.sd
+++ b/config-model/src/test/derived/rankexpression/rankexpression.sd
@@ -5,12 +5,10 @@ search rankexpression {
field artist type string {
indexing: summary | index
- # index-to: artist, default
}
field title type string {
indexing: summary | index
- # index-to: title, default
}
field surl type string {
@@ -21,6 +19,38 @@ search rankexpression {
indexing: summary | attribute
}
+ field foo1 type int {
+ indexing: attribute
+ }
+
+ field foo2 type int {
+ indexing: attribute
+ }
+
+ field foo3 type int {
+ indexing: attribute
+ }
+
+ field foo4 type int {
+ indexing: attribute
+ }
+
+ field bar1 type int {
+ indexing: attribute
+ }
+
+ field bar2 type int {
+ indexing: attribute
+ }
+
+ field bar3 type int {
+ indexing: attribute
+ }
+
+ field bar4 type int {
+ indexing: attribute
+ }
+
}
rank-profile default {
@@ -33,7 +63,7 @@ search rankexpression {
expression: if(3>2,4,2)
rerank-count: 10
}
- rank-features: attribute(foo1).out attribute(bar1.out)
+ rank-features: attribute(foo1).out attribute(bar1)
rank-features { attribute(foo2).out attribute(bar2).out }
rank-features {
attribute(foo3).out attribute(bar3).out }
@@ -65,7 +95,7 @@ search rankexpression {
file:rankexpression
}
}
- summary-features: attribute(foo1).out attribute(bar1.out)
+ summary-features: attribute(foo1).out attribute(bar1)
summary-features { attribute(foo2).out attribute(bar2).out }
summary-features {
attribute(foo3).out attribute(bar3).out }
diff --git a/config-model/src/test/derived/rankexpression/summary.cfg b/config-model/src/test/derived/rankexpression/summary.cfg
index 00df2e87144..9752a9f55e3 100644
--- a/config-model/src/test/derived/rankexpression/summary.cfg
+++ b/config-model/src/test/derived/rankexpression/summary.cfg
@@ -15,9 +15,25 @@ classes[0].fields[5].name "summaryfeatures"
classes[0].fields[5].type "featuredata"
classes[0].fields[6].name "documentid"
classes[0].fields[6].type "longstring"
-classes[1].id 1787488393
+classes[1].id 1736696699
classes[1].name "attributeprefetch"
classes[1].fields[0].name "year"
+classes[].fields[].type "integer"
+classes[].fields[].name "foo1"
+classes[].fields[].type "integer"
+classes[].fields[].name "foo2"
+classes[].fields[].type "integer"
+classes[].fields[].name "foo3"
+classes[].fields[].type "integer"
+classes[].fields[].name "foo4"
+classes[].fields[].type "integer"
+classes[].fields[].name "bar1"
+classes[].fields[].type "integer"
+classes[].fields[].name "bar2"
+classes[].fields[].type "integer"
+classes[].fields[].name "bar3"
+classes[].fields[].type "integer"
+classes[].fields[].name "bar4"
classes[1].fields[0].type "integer"
classes[1].fields[1].name "rankfeatures"
classes[1].fields[1].type "featuredata"
diff --git a/config-model/src/test/derived/rankexpression/summarymap.cfg b/config-model/src/test/derived/rankexpression/summarymap.cfg
index c810f7282ba..21e6cdf346f 100644
--- a/config-model/src/test/derived/rankexpression/summarymap.cfg
+++ b/config-model/src/test/derived/rankexpression/summarymap.cfg
@@ -7,4 +7,28 @@ override[1].command "rankfeatures"
override[1].arguments ""
override[2].field "summaryfeatures"
override[2].command "summaryfeatures"
-override[2].arguments "" \ No newline at end of file
+override[2].arguments ""
+override[].field "foo1"
+override[].command "attribute"
+override[].arguments "foo1"
+override[].field "foo2"
+override[].command "attribute"
+override[].arguments "foo2"
+override[].field "foo3"
+override[].command "attribute"
+override[].arguments "foo3"
+override[].field "foo4"
+override[].command "attribute"
+override[].arguments "foo4"
+override[].field "bar1"
+override[].command "attribute"
+override[].arguments "bar1"
+override[].field "bar2"
+override[].command "attribute"
+override[].arguments "bar2"
+override[].field "bar3"
+override[].command "attribute"
+override[].arguments "bar3"
+override[].field "bar4"
+override[].command "attribute"
+override[].arguments "bar4" \ No newline at end of file
diff --git a/config-model/src/test/derived/tensor/rank-profiles.cfg b/config-model/src/test/derived/tensor/rank-profiles.cfg
index 2b231e0cda2..b6ad5372c05 100644
--- a/config-model/src/test/derived/tensor/rank-profiles.cfg
+++ b/config-model/src/test/derived/tensor/rank-profiles.cfg
@@ -35,7 +35,7 @@ rankprofile[3].name "profile2"
rankprofile[3].fef.property[0].name "vespa.rank.firstphase"
rankprofile[3].fef.property[0].value "rankingExpression(firstphase)"
rankprofile[3].fef.property[1].name "rankingExpression(firstphase).rankingScript"
-rankprofile[3].fef.property[1].value "reduce(join(attribute(f4), tensor(x[2],y[2],z[3])((x==y)*(y==z)), f(a,b)(a * b)), sum, x)"
+rankprofile[3].fef.property[1].value "reduce(reduce(join(attribute(f4), tensor(x[2],y[2],z[3])((x==y)*(y==z)), f(a,b)(a * b)), sum, x), sum)"
rankprofile[3].fef.property[2].name "vespa.type.attribute.f2"
rankprofile[3].fef.property[2].value "tensor(x[2],y[])"
rankprofile[3].fef.property[3].name "vespa.type.attribute.f3"
diff --git a/config-model/src/test/derived/tensor/tensor.sd b/config-model/src/test/derived/tensor/tensor.sd
index a6a9a98db3a..3d64f6b807e 100644
--- a/config-model/src/test/derived/tensor/tensor.sd
+++ b/config-model/src/test/derived/tensor/tensor.sd
@@ -28,7 +28,7 @@ search tensor {
rank-profile profile2 {
first-phase {
- expression: matmul(attribute(f4), diag(x[2],y[2],z[3]), x)
+ expression: sum(matmul(attribute(f4), diag(x[2],y[2],z[3]), x))
}
}
diff --git a/config-model/src/test/examples/rankpropvars.sd b/config-model/src/test/examples/rankpropvars.sd
index 40f9e73f35a..28959edbc09 100644
--- a/config-model/src/test/examples/rankpropvars.sd
+++ b/config-model/src/test/examples/rankpropvars.sd
@@ -18,8 +18,8 @@ first-phase {
second-phase {
expression {
if (attribute(artist) == query(testvar1),
- 0.0 * fieldMatch(title) + 0.0 * attribute(popularity) + 0.0 * fieldMatch(artist),
- 0.0 * attribute(popularity) + 0.0 * fieldMatch(artist) + 0.0 * fieldMatch(title))
+ 0.0 * fieldMatch(title) + 0.0 * attribute(Popularity) + 0.0 * fieldMatch(artist),
+ 0.0 * attribute(Popularity) + 0.0 * fieldMatch(artist) + 0.0 * fieldMatch(title))
}
}
@@ -42,8 +42,8 @@ first-phase {
second-phase {
expression {
if (attribute(artist) == query(testvar1),
- 0.0 * fieldMatch(title) + 0.0 * attribute(popularity) + 0.0 * fieldMatch(artist),
- 0.0 * attribute(popularity) + 0.0 * fieldMatch(artist) + 0.0 * fieldMatch(title))
+ 0.0 * fieldMatch(title) + 0.0 * attribute(Popularity) + 0.0 * fieldMatch(artist),
+ 0.0 * attribute(Popularity) + 0.0 * fieldMatch(artist) + 0.0 * fieldMatch(title))
}
}
}
diff --git a/config-model/src/test/examples/simple.sd b/config-model/src/test/examples/simple.sd
index 4fda7f5039e..96b0fa98098 100644
--- a/config-model/src/test/examples/simple.sd
+++ b/config-model/src/test/examples/simple.sd
@@ -116,7 +116,7 @@ search simple {
first-phase {
keep-rank-count:200
rank-score-drop-limit: -13.0
- expression: attribute(year)
+ expression: attribute(popularity)
}
second-phase {
rerank-count: 99
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/FeatureNamesTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/FeatureNamesTestCase.java
index 1f60ad870ec..aa01070d296 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/FeatureNamesTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/FeatureNamesTestCase.java
@@ -18,17 +18,6 @@ import static org.junit.Assert.assertFalse;
public class FeatureNamesTestCase {
@Test
- public void testCanonicalization() {
- assertFalse(FeatureNames.canonicalizeIfValid("foo").isPresent());
- assertEquals("query(bar)", FeatureNames.canonicalize("query(bar)"));
- assertEquals("query(bar)", FeatureNames.canonicalize("query('bar')"));
- assertEquals("constant(bar)", FeatureNames.canonicalize("constant(\"bar\")"));
- assertEquals("query(\"ba.r\")", FeatureNames.canonicalize("query(ba.r)"));
- assertEquals("query(\"ba.r\")", FeatureNames.canonicalize("query('ba.r')"));
- assertEquals("attribute(\"ba.r\")", FeatureNames.canonicalize("attribute(\"ba.r\")"));
- }
-
- @Test
public void testArgument() {
assertFalse(FeatureNames.argumentOf("foo(bar)").isPresent());
assertFalse(FeatureNames.argumentOf("foo(bar.baz)").isPresent());
@@ -42,17 +31,20 @@ public class FeatureNamesTestCase {
@Test
public void testConstantFeature() {
- assertEquals("constant(\"foo/bar\")", FeatureNames.asConstantFeature("foo/bar"));
+ assertEquals("constant(\"foo/bar\")",
+ FeatureNames.asConstantFeature("foo/bar").toString());
}
@Test
public void testAttributeFeature() {
- assertEquals("attribute(foo)", FeatureNames.asAttributeFeature("foo"));
+ assertEquals("attribute(foo)",
+ FeatureNames.asAttributeFeature("foo").toString());
}
@Test
public void testQueryFeature() {
- assertEquals("query(\"foo.bar\")", FeatureNames.asQueryFeature("foo.bar"));
+ assertEquals("query(\"foo.bar\")",
+ FeatureNames.asQueryFeature("foo.bar").toString());
}
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java
index 442c8bd41bd..11093d9f008 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java
@@ -135,13 +135,13 @@ public class RankProfileTestCase extends SearchDefinitionTestCase {
@Test
public void requireThatConfigIsDerivedForQueryFeatureTypeSettings() throws ParseException {
RankProfileRegistry registry = new RankProfileRegistry();
- SearchBuilder builder = new SearchBuilder(registry);
+ SearchBuilder builder = new SearchBuilder(registry, setupQueryProfileTypes());
builder.importString("search test {\n" +
" document test { } \n" +
" rank-profile p1 {}\n" +
" rank-profile p2 {}\n" +
"}");
- builder.build(new BaseDeployLogger(), setupQueryProfileTypes());
+ builder.build(new BaseDeployLogger());
Search search = builder.getSearch();
assertEquals(4, registry.allRankProfiles().size());
@@ -151,7 +151,7 @@ public class RankProfileTestCase extends SearchDefinitionTestCase {
assertQueryFeatureTypeSettings(registry.getRankProfile(search, "p2"), search);
}
- private static QueryProfiles setupQueryProfileTypes() {
+ private static QueryProfileRegistry setupQueryProfileTypes() {
QueryProfileRegistry registry = new QueryProfileRegistry();
QueryProfileTypeRegistry typeRegistry = registry.getTypeRegistry();
QueryProfileType type = new QueryProfileType(new ComponentId("testtype"));
@@ -164,7 +164,7 @@ public class RankProfileTestCase extends SearchDefinitionTestCase {
type.addField(new FieldDescription("ranking.features.query(numeric)",
FieldType.fromString("integer", typeRegistry)), typeRegistry);
typeRegistry.register(type);
- return new QueryProfiles(registry);
+ return registry;
}
private static void assertQueryFeatureTypeSettings(RankProfile profile, Search search) {
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java
index e94880e61c7..82b9f5ac043 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java
@@ -207,6 +207,9 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase
builder.importString(
"search test {\n" +
" document test { \n" +
+ " field rating_yelp type int {" +
+ " indexing: attribute" +
+ " }" +
" }\n" +
" \n" +
" rank-profile test {\n" +
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java
index 5100ac15c40..ed1b00e2875 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java
@@ -2,7 +2,10 @@
package com.yahoo.searchdefinition;
import com.yahoo.collections.Pair;
+import com.yahoo.search.query.profile.QueryProfile;
import com.yahoo.search.query.profile.QueryProfileRegistry;
+import com.yahoo.search.query.profile.types.FieldDescription;
+import com.yahoo.search.query.profile.types.QueryProfileType;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RawRankProfile;
import com.yahoo.searchdefinition.parser.ParseException;
@@ -149,11 +152,12 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase
censorBindingHash(testRankProperties.get(4).toString()));
}
-
@Test
public void testNeuralNetworkSetup() throws ParseException {
+ // Note: the type assigned to query profile and constant tensors here is not the correct type
RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
- SearchBuilder builder = new SearchBuilder(rankProfileRegistry);
+ QueryProfileRegistry queryProfiles = queryProfileWith("query(q)", "tensor(x[])");
+ SearchBuilder builder = new SearchBuilder(rankProfileRegistry, queryProfiles);
builder.importString(
"search test {\n" +
" document test { \n" +
@@ -176,13 +180,28 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase
" expression: sum(final_layer)\n" +
" }\n" +
" }\n" +
- "\n" +
+ " constant W_hidden {\n" +
+ " type: tensor(x[])\n" +
+ " file: ignored.json\n" +
+ " }\n" +
+ " constant b_input {\n" +
+ " type: tensor(x[])\n" +
+ " file: ignored.json\n" +
+ " }\n" +
+ " constant W_final {\n" +
+ " type: tensor(x[])\n" +
+ " file: ignored.json\n" +
+ " }\n" +
+ " constant b_final {\n" +
+ " type: tensor(x[])\n" +
+ " file: ignored.json\n" +
+ " }\n" +
"}\n");
builder.build();
Search s = builder.getSearch();
- RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry());
+ RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(queryProfiles);
List<Pair<String, String>> testRankProperties = new RawRankProfile(test,
- new QueryProfileRegistry(),
+ queryProfiles,
new AttributeFields(s)).configProperties();
assertEquals("(rankingExpression(relu).rankingScript,max(1.0,x))",
testRankProperties.get(0).toString());
@@ -198,6 +217,17 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase
testRankProperties.get(5).toString());
}
+ private QueryProfileRegistry queryProfileWith(String field, String type) {
+ QueryProfileType queryProfileType = new QueryProfileType("root");
+ queryProfileType.addField(new FieldDescription(field, type));
+ QueryProfileRegistry queryProfileRegistry = new QueryProfileRegistry();
+ queryProfileRegistry.getTypeRegistry().register(queryProfileType);
+ QueryProfile profile = new QueryProfile("default");
+ profile.setType(queryProfileType);
+ queryProfileRegistry.register(profile);
+ return queryProfileRegistry;
+ }
+
private String censorBindingHash(String s) {
StringBuilder b = new StringBuilder();
boolean areInHash = false;
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
index 800697b3430..0ce6129ef7f 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
@@ -38,7 +38,8 @@ class RankProfileSearchFixture {
RankProfileSearchFixture(ApplicationPackage applicationpackage, QueryProfileRegistry queryProfileRegistry,
String rankProfiles, String constant, String field)
throws ParseException {
- SearchBuilder builder = new SearchBuilder(applicationpackage, rankProfileRegistry, new QueryProfileRegistry());
+ this.queryProfileRegistry = queryProfileRegistry;
+ SearchBuilder builder = new SearchBuilder(applicationpackage, rankProfileRegistry, queryProfileRegistry);
String sdContent = "search test {\n" +
" " + (constant != null ? constant : "") + "\n" +
" document test {\n" +
@@ -50,7 +51,6 @@ class RankProfileSearchFixture {
builder.importString(sdContent);
builder.build();
search = builder.getSearch();
- this.queryProfileRegistry = queryProfileRegistry;
}
public void assertFirstPhaseExpression(String expExpression, String rankProfile) {
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidatorTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidatorTestCase.java
new file mode 100644
index 00000000000..056d2ad534b
--- /dev/null
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidatorTestCase.java
@@ -0,0 +1,192 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchdefinition.processing;
+
+import com.yahoo.searchdefinition.RankProfile;
+import com.yahoo.searchdefinition.RankProfileRegistry;
+import com.yahoo.searchdefinition.Search;
+import com.yahoo.searchdefinition.SearchBuilder;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.yolean.Exceptions;
+import org.junit.Test;
+
+import java.util.Map;
+import java.util.stream.Collectors;
+
+import static com.yahoo.config.model.test.TestUtil.joinLines;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+/**
+ * @author bratseth
+ */
+public class RankingExpressionTypeValidatorTestCase {
+
+ @Test
+ public void tensorFirstPhaseMustProduceDouble() throws Exception {
+ try {
+ SearchBuilder builder = new SearchBuilder();
+ builder.importString(joinLines(
+ "search test {",
+ " document test { ",
+ " field a type tensor(x[],y[]) {",
+ " indexing: attribute",
+ " }",
+ " }",
+ " rank-profile my_rank_profile {",
+ " first-phase {",
+ " expression: attribute(a)",
+ " }",
+ " }",
+ "}"
+ ));
+ builder.build();
+ fail("Expected exception");
+ }
+ catch (IllegalArgumentException expected) {
+ assertEquals("In search definition 'test', rank profile 'my_rank_profile': The first-phase expression must produce a double (a tensor with no dimensions), but produces tensor(x[],y[])",
+ Exceptions.toMessageString(expected));
+ }
+ }
+
+ @Test
+ public void tensorSecondPhaseMustProduceDouble() throws Exception {
+ try {
+ SearchBuilder builder = new SearchBuilder();
+ builder.importString(joinLines(
+ "search test {",
+ " document test { ",
+ " field a type tensor(x[],y[]) {",
+ " indexing: attribute",
+ " }",
+ " }",
+ " rank-profile my_rank_profile {",
+ " first-phase {",
+ " expression: sum(attribute(a))",
+ " }",
+ " second-phase {",
+ " expression: attribute(a)",
+ " }",
+ " }",
+ "}"
+ ));
+ builder.build();
+ fail("Expected exception");
+ }
+ catch (IllegalArgumentException expected) {
+ assertEquals("In search definition 'test', rank profile 'my_rank_profile': The second-phase expression must produce a double (a tensor with no dimensions), but produces tensor(x[],y[])",
+ Exceptions.toMessageString(expected));
+ }
+ }
+
+ @Test
+ public void tensorConditionsMustHaveTypeCompatibleBranches() throws Exception {
+ try {
+ SearchBuilder searchBuilder = new SearchBuilder();
+ searchBuilder.importString(joinLines(
+ "search test {",
+ " document test { ",
+ " field a type tensor(x[],y[]) {",
+ " indexing: attribute",
+ " }",
+ " field b type tensor(z[10]) {",
+ " indexing: attribute",
+ " }",
+ " }",
+ " rank-profile my_rank_profile {",
+ " first-phase {",
+ " expression: sum(if(1>0, attribute(a), attribute(b)))",
+ " }",
+ " }",
+ "}"
+ ));
+ searchBuilder.build();
+ fail("Expected exception");
+ }
+ catch (IllegalArgumentException expected) {
+ assertEquals("In search definition 'test', rank profile 'my_rank_profile': The first-phase expression is invalid: An if expression must produce compatible types in both alternatives, but the 'true' type is tensor(x[],y[]) while the 'false' type is tensor(z[10])",
+ Exceptions.toMessageString(expected));
+ }
+ }
+
+ @Test
+ public void testMacroInvocationTypes() throws Exception {
+ RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
+ SearchBuilder builder = new SearchBuilder(rankProfileRegistry);
+ builder.importString(joinLines(
+ "search test {",
+ " document test { ",
+ " field a type tensor(x[],y[]) {",
+ " indexing: attribute",
+ " }",
+ " field b type tensor(z[10]) {",
+ " indexing: attribute",
+ " }",
+ " }",
+ " rank-profile my_rank_profile {",
+ " macro macro1(attribute_to_use) {",
+ " expression: attribute(attribute_to_use)",
+ " }",
+ " summary-features {",
+ " macro1(a)",
+ " macro1(b)",
+ " }",
+ " }",
+ "}"
+ ));
+ builder.build();
+ RankProfile profile =
+ builder.getRankProfileRegistry().getRankProfile(builder.getSearch(), "my_rank_profile");
+ assertEquals(TensorType.fromSpec("tensor(x[],y[])"),
+ summaryFeatures(profile).get("macro1(a)").type(profile.typeContext(builder.getQueryProfileRegistry())));
+ assertEquals(TensorType.fromSpec("tensor(z[10])"),
+ summaryFeatures(profile).get("macro1(b)").type(profile.typeContext(builder.getQueryProfileRegistry())));
+ }
+
+ @Test
+ public void testTensorMacroInvocationTypes_Nested() throws Exception {
+ SearchBuilder builder = new SearchBuilder();
+ builder.importString(joinLines(
+ "search test {",
+ " document test { ",
+ " field a type tensor(x[],y[]) {",
+ " indexing: attribute",
+ " }",
+ " field b type tensor(z[10]) {",
+ " indexing: attribute",
+ " }",
+ " }",
+ " rank-profile my_rank_profile {",
+ " macro return_a() {",
+ " expression: return_first(attribute(a), attribute(b))",
+ " }",
+ " macro return_b() {",
+ " expression: return_second(attribute(a), attribute(b))",
+ " }",
+ " macro return_first(e1, e2) {",
+ " expression: e1",
+ " }",
+ " macro return_second(e1, e2) {",
+ " expression: return_first(e2, e1)",
+ " }",
+ " summary-features {",
+ " return_a",
+ " return_b",
+ " }",
+ " }",
+ "}"
+ ));
+ builder.build();
+ RankProfile profile =
+ builder.getRankProfileRegistry().getRankProfile(builder.getSearch(), "my_rank_profile");
+ assertEquals(TensorType.fromSpec("tensor(x[],y[])"),
+ summaryFeatures(profile).get("return_a").type(profile.typeContext(builder.getQueryProfileRegistry())));
+ assertEquals(TensorType.fromSpec("tensor(z[10])"),
+ summaryFeatures(profile).get("return_b").type(profile.typeContext(builder.getQueryProfileRegistry())));
+ }
+
+ private Map<String, ReferenceNode> summaryFeatures(RankProfile profile) {
+ return profile.getSummaryFeatures().stream().collect(Collectors.toMap(f -> f.toString(), f -> f));
+ }
+
+}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index 4693ac5cf4d..96795d2b08f 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -73,7 +73,7 @@ public class RankingExpressionWithTensorFlowTestCase {
public void testTensorFlowReferenceWithQueryFeature() {
String queryProfile = "<query-profile id='default' type='root'/>";
String queryProfileType = "<query-profile-type id='root'>" +
- " <field name='mytensor' type='tensor(d0[3],d1[784])'/>" +
+ " <field name='query(mytensor)' type='tensor(d0[3],d1[784])'/>" +
"</query-profile-type>";
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir,
queryProfile,
@@ -107,7 +107,7 @@ public class RankingExpressionWithTensorFlowTestCase {
public void testTensorFlowReferenceWithFeatureCombination() {
String queryProfile = "<query-profile id='default' type='root'/>";
String queryProfileType = "<query-profile-type id='root'>" +
- " <field name='mytensor' type='tensor(d0[3],d1[784],d2[10])'/>" +
+ " <field name='query(mytensor)' type='tensor(d0[3],d1[784],d2[10])'/>" +
"</query-profile-type>";
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir,
queryProfile,
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
index b001db69768..054c9220225 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
@@ -17,98 +17,129 @@ import com.yahoo.searchdefinition.SearchDefinitionTestCase;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RawRankProfile;
import com.yahoo.searchdefinition.parser.ParseException;
-import com.yahoo.vespa.model.container.search.QueryProfiles;
import org.junit.Test;
import java.util.List;
+import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
public class TensorTransformTestCase extends SearchDefinitionTestCase {
@Test
public void requireThatNormalMaxAndMinAreNotReplaced() throws ParseException {
- assertContainsExpression("max(1.0,2.0)", "max(1.0,2.0)");
- assertContainsExpression("min(attribute(double_field),x)", "min(attribute(double_field),x)");
- assertContainsExpression("max(attribute(double_field),attribute(double_array_field))", "max(attribute(double_field),attribute(double_array_field))");
- assertContainsExpression("min(attribute(tensor_field_1),attribute(double_field))", "min(attribute(tensor_field_1),attribute(double_field))");
- assertContainsExpression("max(attribute(tensor_field_1),attribute(tensor_field_2))", "max(attribute(tensor_field_1),attribute(tensor_field_2))");
- assertContainsExpression("min(test_constant_tensor,1.0)", "min(constant(test_constant_tensor),1.0)");
- assertContainsExpression("max(base_constant_tensor,1.0)", "max(constant(base_constant_tensor),1.0)");
- assertContainsExpression("min(constant(file_constant_tensor),1.0)", "min(constant(file_constant_tensor),1.0)");
- assertContainsExpression("max(query(q),1.0)", "max(query(q),1.0)");
- assertContainsExpression("max(query(n),1.0)", "max(query(n),1.0)");
+ assertTransformedExpression("max(1.0,2.0)",
+ "max(1.0,2.0)");
+ assertTransformedExpression("min(attribute(double_field),x)",
+ "min(attribute(double_field),x)");
+ assertTransformedExpression("max(attribute(double_field),attribute(double_array_field))",
+ "max(attribute(double_field),attribute(double_array_field))");
+ assertTransformedExpression("min(attribute(tensor_field_1),attribute(double_field))",
+ "min(attribute(tensor_field_1),attribute(double_field))");
+ assertTransformedExpression("reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),sum)",
+ "reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),sum)");
+ assertTransformedExpression("min(constant(test_constant_tensor),1.0)",
+ "min(test_constant_tensor,1.0)");
+ assertTransformedExpression("max(constant(base_constant_tensor),1.0)",
+ "max(base_constant_tensor,1.0)");
+ assertTransformedExpression("min(constant(file_constant_tensor),1.0)",
+ "min(constant(file_constant_tensor),1.0)");
+ assertTransformedExpression("max(query(q),1.0)",
+ "max(query(q),1.0)");
+ assertTransformedExpression("max(query(n),1.0)",
+ "max(query(n),1.0)");
}
@Test
public void requireThatMaxAndMinWithTensorAttributesAreReplaced() throws ParseException {
- assertContainsExpression("max(attribute(tensor_field_1),x)", "reduce(attribute(tensor_field_1),max,x)");
- assertContainsExpression("1 + max(attribute(tensor_field_1),x)", "1+reduce(attribute(tensor_field_1),max,x)");
- assertContainsExpression("if(attribute(double_field),1 + max(attribute(tensor_field_1),x),0)", "if(attribute(double_field),1+reduce(attribute(tensor_field_1),max,x),0)");
- assertContainsExpression("max(max(attribute(tensor_field_1),attribute(tensor_field_2)),x)", "reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),max,x)");
- assertContainsExpression("max(if(attribute(double_field),attribute(tensor_field_1),attribute(tensor_field_2)),x)", "reduce(if(attribute(double_field),attribute(tensor_field_1),attribute(tensor_field_2)),max,x)");
- assertContainsExpression("max(max(attribute(tensor_field_1),x),x)", "max(reduce(attribute(tensor_field_1),max,x),x)"); // will result in deploy error.
- assertContainsExpression("max(max(attribute(tensor_field_2),x),y)", "reduce(reduce(attribute(tensor_field_2),max,x),max,y)");
+ assertTransformedExpression("reduce(attribute(tensor_field_1),max,x)",
+ "max(attribute(tensor_field_1),x)");
+ assertTransformedExpression("1+reduce(attribute(tensor_field_1),max,x)",
+ "1 + max(attribute(tensor_field_1),x)");
+ assertTransformedExpression("if(attribute(double_field),1+reduce(attribute(tensor_field_1),max,x),0)",
+ "if(attribute(double_field),1 + max(attribute(tensor_field_1),x),0)");
+ assertTransformedExpression("reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),max,x)",
+ "max(max(attribute(tensor_field_1),attribute(tensor_field_2)),x)");
+ assertTransformedExpression("reduce(if(attribute(double_field),attribute(tensor_field_2),attribute(tensor_field_2)),max,x)",
+ "max(if(attribute(double_field),attribute(tensor_field_2),attribute(tensor_field_2)),x)");
+ assertTransformedExpression("max(reduce(attribute(tensor_field_1),max,x),x)",
+ "max(max(attribute(tensor_field_1),x),x)"); // will result in deploy error.
+ assertTransformedExpression("reduce(reduce(attribute(tensor_field_2),max,x),max,y)",
+ "max(max(attribute(tensor_field_2),x),y)");
}
@Test
public void requireThatMaxAndMinWithConstantTensorsAreReplaced() throws ParseException {
- assertContainsExpression("max(test_constant_tensor,x)", "reduce(constant(test_constant_tensor),max,x)");
- assertContainsExpression("max(base_constant_tensor,x)", "reduce(constant(base_constant_tensor),max,x)");
- assertContainsExpression("min(constant(file_constant_tensor),x)", "reduce(constant(file_constant_tensor),min,x)");
+ assertTransformedExpression("reduce(constant(test_constant_tensor),max,x)",
+ "max(test_constant_tensor,x)");
+ assertTransformedExpression("reduce(constant(base_constant_tensor),max,x)",
+ "max(base_constant_tensor,x)");
+ assertTransformedExpression("reduce(constant(file_constant_tensor),min,x)",
+ "min(constant(file_constant_tensor),x)");
}
@Test
public void requireThatMaxAndMinWithTensorExpressionsAreReplaced() throws ParseException {
- assertContainsExpression("min(attribute(double_field) + attribute(tensor_field_1),x)", "reduce(attribute(double_field)+attribute(tensor_field_1),min,x)");
- assertContainsExpression("min(attribute(tensor_field_1) * attribute(tensor_field_2),x)", "reduce(attribute(tensor_field_1)*attribute(tensor_field_2),min,x)");
- assertContainsExpression("min(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),x)", "reduce(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),min,x)");
- assertContainsExpression("min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)", "min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)"); // because tensor fields are not in attribute(...)
- assertContainsExpression("min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)", "reduce(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),min,x)");
+ assertTransformedExpression("reduce(attribute(double_field)+attribute(tensor_field_1),min,x)",
+ "min(attribute(double_field) + attribute(tensor_field_1),x)");
+ assertTransformedExpression("reduce(attribute(tensor_field_1)*attribute(tensor_field_2),min,x)",
+ "min(attribute(tensor_field_1) * attribute(tensor_field_2),x)");
+ assertTransformedExpression("reduce(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),min,x)",
+ "min(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),x)");
+ assertTransformedExpression("min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)",
+ "min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)"); // because tensor fields are not in attribute(...)
+ assertTransformedExpression("reduce(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),min,x)",
+ "min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)");
}
@Test
public void requireThatMaxAndMinWithTensorFromIsReplaced() throws ParseException {
- assertContainsExpression("max(tensorFromLabels(attribute(double_array_field)),double_array_field)", "reduce(tensorFromLabels(attribute(double_array_field)),max,double_array_field)");
- assertContainsExpression("max(tensorFromLabels(attribute(double_array_field),x),x)", "reduce(tensorFromLabels(attribute(double_array_field),x),max,x)");
- assertContainsExpression("max(tensorFromWeightedSet(attribute(weightedset_field)),weightedset_field)", "reduce(tensorFromWeightedSet(attribute(weightedset_field)),max,weightedset_field)");
- assertContainsExpression("max(tensorFromWeightedSet(attribute(weightedset_field),x),x)", "reduce(tensorFromWeightedSet(attribute(weightedset_field),x),max,x)");
+ assertTransformedExpression("reduce(tensorFromLabels(attribute(double_array_field)),max,double_array_field)",
+ "max(tensorFromLabels(attribute(double_array_field)),double_array_field)");
+ assertTransformedExpression("reduce(tensorFromLabels(attribute(double_array_field),x),max,x)",
+ "max(tensorFromLabels(attribute(double_array_field),x),x)");
+ assertTransformedExpression("reduce(tensorFromWeightedSet(attribute(weightedset_field)),max,weightedset_field)",
+ "max(tensorFromWeightedSet(attribute(weightedset_field)),weightedset_field)");
+ assertTransformedExpression("reduce(tensorFromWeightedSet(attribute(weightedset_field),x),max,x)",
+ "max(tensorFromWeightedSet(attribute(weightedset_field),x),x)");
}
@Test
public void requireThatMaxAndMinWithTensorInQueryIsReplaced() throws ParseException {
- assertContainsExpression("max(query(q),x)", "reduce(query(q),max,x)");
- assertContainsExpression("max(query(n),x)", "max(query(n),x)");
+ assertTransformedExpression("reduce(query(q),max,x)", "max(query(q),x)");
+ assertTransformedExpression("max(query(n),x)", "max(query(n),x)");
}
@Test
public void requireThatMaxAndMinWithTensoresReturnedFromMacrosAreReplaced() throws ParseException {
- assertContainsExpression("max(returns_tensor,x)", "reduce(rankingExpression(returns_tensor),max,x)");
- assertContainsExpression("max(wraps_returns_tensor,x)", "reduce(rankingExpression(wraps_returns_tensor),max,x)");
- assertContainsExpression("max(tensor_inheriting,x)", "reduce(rankingExpression(tensor_inheriting),max,x)");
- assertContainsExpression("max(returns_tensor_with_arg(attribute(tensor_field_1)),x)", "reduce(rankingExpression(returns_tensor_with_arg@),max,x)");
+ assertTransformedExpression("reduce(rankingExpression(returns_tensor),max,x)",
+ "max(returns_tensor,x)");
+ assertTransformedExpression("reduce(rankingExpression(wraps_returns_tensor),max,x)",
+ "max(wraps_returns_tensor,x)");
+ assertTransformedExpression("reduce(rankingExpression(tensor_inheriting),max,x)",
+ "max(tensor_inheriting,x)");
+ assertTransformedExpression("reduce(rankingExpression(returns_tensor_with_arg@),max,x)",
+ "max(returns_tensor_with_arg(attribute(tensor_field_1)),x)");
}
- private void assertContainsExpression(String expr, String transformedExpression) throws ParseException {
- assertTrue("Expected expression '" + transformedExpression + "' found",
- containsExpression(expr, transformedExpression));
- }
-
- private boolean containsExpression(String expr, String transformedExpression) throws ParseException {
- for (Pair<String, String> rankPropertyExpression : buildSearch(expr)) {
+ private void assertTransformedExpression(String expected, String original) throws ParseException {
+ for (Pair<String, String> rankPropertyExpression : buildSearch(original)) {
String rankProperty = rankPropertyExpression.getFirst();
if (rankProperty.equals("rankingExpression(firstphase).rankingScript")) {
String rankExpression = censorBindingHash(rankPropertyExpression.getSecond().replace(" ",""));
- return rankExpression.equals(transformedExpression);
+ assertEquals(expected, rankExpression);
+ return;
}
}
- return false;
+ fail("No 'rankingExpression(firstphase).rankingScript' property produced");
}
private List<Pair<String, String>> buildSearch(String expression) throws ParseException {
RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
- SearchBuilder builder = new SearchBuilder(rankProfileRegistry);
+ QueryProfileRegistry queryProfiles = setupQueryProfileTypes();
+ SearchBuilder builder = new SearchBuilder(rankProfileRegistry, queryProfiles);
builder.importString(
"search test {\n" +
" document test { \n" +
@@ -167,16 +198,16 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase {
" }\n" +
" }\n" +
"}\n");
- builder.build(new BaseDeployLogger(), setupQueryProfileTypes());
+ builder.build(new BaseDeployLogger());
Search s = builder.getSearch();
- RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry());
+ RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(queryProfiles);
List<Pair<String, String>> testRankProperties = new RawRankProfile(test,
- new QueryProfileRegistry(),
+ queryProfiles,
new AttributeFields(s)).configProperties();
return testRankProperties;
}
- private static QueryProfiles setupQueryProfileTypes() {
+ private static QueryProfileRegistry setupQueryProfileTypes() {
QueryProfileRegistry registry = new QueryProfileRegistry();
QueryProfileTypeRegistry typeRegistry = registry.getTypeRegistry();
QueryProfileType type = new QueryProfileType(new ComponentId("testtype"));
@@ -185,7 +216,7 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase {
type.addField(new FieldDescription("ranking.features.query(n)",
FieldType.fromString("integer", typeRegistry)), typeRegistry);
typeRegistry.register(type);
- return new QueryProfiles(registry);
+ return registry;
}
private String censorBindingHash(String s) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
index 2e2858da238..262aba89f27 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
@@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression;
import com.google.common.collect.ImmutableList;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.FunctionReferenceContext;
import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
import com.yahoo.text.Utf8;
@@ -11,9 +12,9 @@ import java.security.NoSuchAlgorithmException;
import java.util.*;
/**
- * <p>A function defined by a ranking expression</p>
+ * A function defined by a ranking expression
*
- * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a>
+ * @author Simon Thoresen
* @author bratseth
*/
public class ExpressionFunction {
@@ -23,7 +24,7 @@ public class ExpressionFunction {
private final RankingExpression body;
/**
- * <p>Constructs a new function</p>
+ * Constructs a new function
*
* @param name the name of this function
* @param arguments its argument names
@@ -43,28 +44,27 @@ public class ExpressionFunction {
public RankingExpression getBody() { return body; }
/**
- * <p>Create and return an instance of this function based on the given
- * arguments. If function calls are nested, this call might produce
- * additional scripts.</p>
+ * Creates and returns an instance of this function based on the given
+ * arguments. If function calls are nested, this call may produce
+ * additional functions.
*
* @param context the context used to expand this
- * @param arguments the arguments to instantiate on.
+ * @param argumentValues the arguments to instantiate on.
* @param path the expansion path leading to this.
* @return the script function instance created.
*/
- public Instance expand(SerializationContext context, List<ExpressionNode> arguments, Deque<String> path) {
+ public Instance expand(SerializationContext context, List<ExpressionNode> argumentValues, Deque<String> path) {
Map<String, String> argumentBindings = new HashMap<>();
- for (int i = 0; i < this.arguments.size() && i < arguments.size(); ++i) {
- argumentBindings.put(this.arguments.get(i), arguments.get(i).toString(context, path, null));
+ for (int i = 0; i < arguments.size() && i < arguments.size(); ++i) {
+ argumentBindings.put(arguments.get(i), argumentValues.get(i).toString(context, path, null));
}
- return new Instance(toSymbol(argumentBindings), body.getRoot().toString(context.createBinding(argumentBindings), path, null));
+ return new Instance(toSymbol(argumentBindings), body.getRoot().toString(context.withBindings(argumentBindings), path, null));
}
/**
* Returns a symbolic string that represents this function with a given
* list of arguments. The arguments are mangled by hashing the string
- * representation of the argument expressions, so we might need to revisit
- * this if we start seeing collisions.
+ * representation of the argument expressions.
*
* @param argumentBindings the bound arguments to include in the symbolic name.
* @return the symbolic name for an instance of this function
@@ -85,8 +85,8 @@ public class ExpressionFunction {
/**
- * <p>Returns a more unique hash code than what Java's own {@link
- * String#hashCode()} method would produce.</p>
+ * Returns a more unique hash code than what Java's own {@link
+ * String#hashCode()} method would produce.
*
* @param str The string to hash.
* @return A 64 bit long hash code.
@@ -136,4 +136,5 @@ public class ExpressionFunction {
}
}
+
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/FeatureList.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/FeatureList.java
index 49466f1974d..f0532d9d433 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/FeatureList.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/FeatureList.java
@@ -91,8 +91,8 @@ public class FeatureList implements Iterable<ReferenceNode> {
/**
* Returns the feature at the given index.
*
- * @param i The index of the feature to return.
- * @return The featuer at the given index.
+ * @param i the index of the feature to return.
+ * @return the feature at the given index.
*/
public ReferenceNode get(int i) {
return features.get(i);
@@ -137,4 +137,5 @@ public class FeatureList implements Iterable<ReferenceNode> {
public Iterator<ReferenceNode> iterator() {
return features.iterator();
}
+
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java
index c8d90e8c4e8..6b2422d7cb2 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java
@@ -244,10 +244,6 @@ public class RankingExpression implements Serializable {
* @return a list of named rank properties required to implement this expression.
*/
public Map<String, String> getRankProperties(List<ExpressionFunction> macros) {
- Map<String, ExpressionFunction> arg = new HashMap<>();
- for (ExpressionFunction function : macros) {
- arg.put(function.getName(), function);
- }
Deque<String> path = new LinkedList<>();
SerializationContext context = new SerializationContext(macros);
String serializedRoot = root.toString(context, path, null);
@@ -272,7 +268,7 @@ public class RankingExpression implements Serializable {
*
* @throws IllegalArgumentException if this expression is not type correct in this context
*/
- public TensorType type(TypeContext context) {
+ public TensorType type(TypeContext<Reference> context) {
return root.type(context);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/Reference.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/Reference.java
new file mode 100644
index 00000000000..6277721e8f5
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/Reference.java
@@ -0,0 +1,121 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression;
+
+import com.yahoo.searchlib.rankingexpression.rule.Arguments;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
+import com.yahoo.tensor.evaluation.TypeContext;
+
+import java.util.Deque;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
+/**
+ * A reference to a feature, function, or value in ranking expressions
+ *
+ * @author bratseth
+ */
+public class Reference extends TypeContext.Name {
+
+ private final String name;
+ private final Arguments arguments;
+
+ /**
+ * The output, or null if none
+ */
+ private final String output;
+
+ public Reference(String name, Arguments arguments, String output) {
+ super(name);
+ Objects.requireNonNull(name, "name cannot be null");
+ Objects.requireNonNull(arguments, "arguments cannot be null");
+ this.name = name;
+ this.arguments = arguments;
+ this.output = output;
+ }
+
+ public String name() { return name; }
+
+ public Arguments arguments() { return arguments; }
+
+ public String output() { return output; }
+
+ /**
+ * Creates a reference to a simple feature consisting of a name and a single argument
+ */
+ public static Reference simple(String name, String argumentValue) {
+ return new Reference(name,
+ new Arguments(new ReferenceNode(argumentValue)),
+ null);
+ }
+
+ /**
+ * Returns the given simple feature as a reference, or empty if it is not a valid simple
+ * feature string on the form name(argument).
+ */
+ public static Optional<Reference> simple(String feature) {
+ int startParenthesis = feature.indexOf('(');
+ if (startParenthesis < 0)
+ return Optional.empty();
+ int endParenthesis = feature.lastIndexOf(')');
+ String featureName = feature.substring(0, startParenthesis);
+ if (startParenthesis < 1 || endParenthesis < startParenthesis) return Optional.empty();
+ String argument = feature.substring(startParenthesis + 1, endParenthesis);
+ if (argument.startsWith("'") || argument.startsWith("\""))
+ argument = argument.substring(1);
+ if (argument.endsWith("'") || argument.endsWith("\""))
+ argument = argument.substring(0, argument.length() - 1);
+ return Optional.of(simple(featureName, argument));
+ }
+
+ /**
+ * Returns whether this is a simple identifier - no arguments or output
+ */
+ public boolean isIdentifier() {
+ return this.arguments.expressions().size() == 0 && output == null;
+ }
+
+ public Reference withArguments(Arguments arguments) {
+ return new Reference(name, arguments, output);
+ }
+
+ public Reference withOutput(String output) {
+ return new Reference(name, arguments, output);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (o == this) return true;
+ if (!(o instanceof Reference)) return false;
+ Reference other = (Reference) o;
+ if (!Objects.equals(other.name, this.name)) return false;
+ if (!Objects.equals(other.arguments, this.arguments)) return false;
+ if (!Objects.equals(other.output, this.output)) return false;
+ return true;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(name, arguments, output);
+ }
+
+ @Override
+ public String toString() {
+ return toString(new SerializationContext(), null, null);
+ }
+
+ public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
+ StringBuilder b = new StringBuilder(name);
+ if (arguments != null && arguments.expressions().size() > 0)
+ b.append("(").append(arguments.expressions().stream()
+ .map(node -> node.toString(context, path, parent))
+ .collect(Collectors.joining(","))).append(")");
+ if (output != null)
+ b.append(".").append(output);
+ return b.toString();
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java
index 5f8daa69ecf..ee5952d9aea 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java
@@ -2,6 +2,7 @@
package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.TensorType;
import java.util.Arrays;
@@ -82,8 +83,8 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable {
}
@Override
- public TensorType getType(String name) {
- Integer index = nameToIndex().get(name);
+ public TensorType getType(Reference reference) {
+ Integer index = nameToIndex().get(reference.toString());
if (index == null) return null;
return values[index].type();
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
index 861f9565d66..4e046df11ca 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
@@ -1,9 +1,11 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.evaluation;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import java.util.Set;
@@ -14,7 +16,7 @@ import java.util.stream.Collectors;
*
* @author bratseth
*/
-public abstract class Context implements EvaluationContext {
+public abstract class Context implements EvaluationContext<Reference> {
/**
* Returns the value of a simple variable name.
@@ -24,6 +26,11 @@ public abstract class Context implements EvaluationContext {
*/
public abstract Value get(String name);
+ @Override
+ public TensorType getType(String reference) {
+ throw new UnsupportedOperationException("Not able to parse gereral references from string form");
+ }
+
/** Returns a variable as a tensor */
@Override
public Tensor getTensor(String name) { return get(name).asTensor(); }
@@ -46,6 +53,7 @@ public abstract class Context implements EvaluationContext {
* calculation to output several), or null to output the
* "main" (or only) value.
*/
+ // TODO: Remove/change to use reference?
public Value get(String name, Arguments arguments, String output) {
if (arguments != null && arguments.expressions().size() > 0)
name = name + "(" + arguments.expressions().stream().map(ExpressionNode::toString).collect(Collectors.joining(",")) + ")";
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java
index 0625e8506cc..0004036da4b 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java
@@ -2,6 +2,7 @@
package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.TensorType;
/**
@@ -68,7 +69,9 @@ public class DoubleOnlyArrayContext extends AbstractArrayContext {
}
@Override
- public TensorType getType(String name) { return TensorType.empty; }
+ public TensorType getType(Reference reference) {
+ return TensorType.empty; // Double only
+ }
/** Perform a slow lookup by name */
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java
index a81d0c89f8f..4ef24d60bba 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.evaluation;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.TensorType;
import java.util.Collections;
@@ -15,7 +16,7 @@ import java.util.Set;
*/
public class MapContext extends Context {
- private Map<String, Value> bindings = new HashMap<>();
+ private Map<String, Value> bindings = new HashMap<>(); // TODO: Change String to Reference
private boolean frozen = false;
@@ -42,8 +43,8 @@ public class MapContext extends Context {
/** Returns the type of the given value key, or null if it is not bound. */
@Override
- public TensorType getType(String key) {
- Value value = bindings.get(key);
+ public TensorType getType(Reference key) {
+ Value value = bindings.get(key.toString());
if (value == null) return null;
return value.type();
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapTypeContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapTypeContext.java
new file mode 100644
index 00000000000..2a42e2d92f7
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapTypeContext.java
@@ -0,0 +1,38 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.evaluation;// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A context which only contains type information.
+ *
+ * @author bratseth
+ */
+public class MapTypeContext implements TypeContext<Reference> {
+
+ private final Map<Reference, TensorType> featureTypes = new HashMap<>();
+
+ public void setType(Reference reference, TensorType type) {
+ featureTypes.put(reference, type);
+ }
+
+ @Override
+ public TensorType getType(String reference) {
+ throw new UnsupportedOperationException("Not able to parse gereral references from string form");
+ }
+
+ @Override
+ public TensorType getType(Reference reference) {
+ return featureTypes.get(reference);
+ }
+
+ /** Returns an unmodifiable map of the bindings in this */
+ public Map<Reference, TensorType> bindings() { return Collections.unmodifiableMap(featureTypes); }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeMapContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeMapContext.java
deleted file mode 100644
index ff2088263d8..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeMapContext.java
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.evaluation;// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.evaluation.TypeContext;
-
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.Map;
-
-/**
- * A context which only contains type information.
- *
- * @author bratseth
- */
-public class TypeMapContext implements TypeContext {
-
- private final Map<String, TensorType> featureTypes = new HashMap<>();
-
- public void setType(String name, TensorType type) {
- featureTypes.put(name, type);
- }
-
- @Override
- public TensorType getType(String name) {
- return featureTypes.get(name);
- }
-
- /** Returns an unmodifiable map of the bindings in this */
- public Map<String, TensorType> bindings() { return Collections.unmodifiableMap(featureTypes); }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java
index 8ee4cdbf297..649c70122f1 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
@@ -26,7 +27,7 @@ public class GBDTForestNode extends ExpressionNode {
}
@Override
- public final TensorType type(TypeContext context) { return TensorType.empty; }
+ public final TensorType type(TypeContext<Reference> context) { return TensorType.empty; }
@Override
public final Value evaluate(Context context) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java
index aac635b2545..53a286f09f6 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
@@ -51,7 +52,7 @@ public final class GBDTNode extends ExpressionNode {
public final double[] values() { return values; }
@Override
- public final TensorType type(TypeContext context) { return TensorType.empty; }
+ public final TensorType type(TypeContext<Reference> context) { return TensorType.empty; }
@Override
public final Value evaluate(Context context) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Arguments.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Arguments.java
index fb9a7cb9ad7..d3a12d0f312 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Arguments.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Arguments.java
@@ -13,7 +13,7 @@ import java.util.List;
/**
* A set of argument expressions to a function or feature.
- * This is immutable.
+ * This is a value object.
*
* @author bratseth
*/
@@ -22,7 +22,11 @@ public final class Arguments implements Serializable {
private final ImmutableList<ExpressionNode> expressions;
public Arguments() {
- this(null);
+ this(ImmutableList.of());
+ }
+
+ public Arguments(ExpressionNode singleArgument) {
+ this(ImmutableList.of(singleArgument));
}
public Arguments(List<? extends ExpressionNode> expressions) {
@@ -38,9 +42,12 @@ public final class Arguments implements Serializable {
this.expressions = b.build();
}
- /** Returns an unmodifiable list of the expressions in this */
+ /** Returns an unmodifiable list of the expressions in this, never null */
public List<ExpressionNode> expressions() { return expressions; }
+ /** Returns the number of arguments in this */
+ public int size() { return expressions.size(); }
+
/** Evaluate all arguments in this */
public Value[] evaluate(Context context) {
Value[] values=new Value[expressions.size()];
@@ -62,8 +69,9 @@ public final class Arguments implements Serializable {
}
@Override
- public boolean equals(Object rhs) {
- return rhs instanceof Arguments && expressions.equals(((Arguments)rhs).expressions);
+ public boolean equals(Object other) {
+ if (other == this) return true;
+ return other instanceof Arguments && expressions.equals(((Arguments)other).expressions);
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java
index fc6428a4c33..49c49bed9bd 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java
@@ -2,6 +2,7 @@
package com.yahoo.searchlib.rankingexpression.rule;
import com.google.common.collect.ImmutableList;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.TensorType;
@@ -80,7 +81,7 @@ public final class ArithmeticNode extends CompositeNode {
}
@Override
- public TensorType type(TypeContext context) {
+ public TensorType type(TypeContext<Reference> context) {
// Compute type using tensor types as arithmetic operators are supported on tensors
// and is correct also in the special case of doubles.
// As all our functions are type-commutative, we don't need to take operator precedence into account
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/BooleanNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/BooleanNode.java
index 1d7d9b1ecda..cd4ddbcae55 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/BooleanNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/BooleanNode.java
@@ -5,7 +5,6 @@ package com.yahoo.searchlib.rankingexpression.rule;
* A node which produces a boolean value when evaluated.
*
* @author bratseth
- * @since 5.1.21
*/
public abstract class BooleanNode extends CompositeNode {
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
index 7601c0e6180..eb328486045 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.rule;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.TensorType;
@@ -49,7 +50,7 @@ public class ComparisonNode extends BooleanNode {
}
@Override
- public TensorType type(TypeContext context) {
+ public TensorType type(TypeContext<Reference> context) {
return TensorType.empty; // by definition
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java
index 1ea8d03f0eb..3ddd7223349 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.rule;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.TensorType;
@@ -49,7 +50,7 @@ public final class ConstantNode extends ExpressionNode {
}
@Override
- public TensorType type(TypeContext context) { return value.type(); }
+ public TensorType type(TypeContext<Reference> context) { return value.type(); }
@Override
public Value evaluate(Context context) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java
index fd9fab99db8..47c2897e4a4 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.rule;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.TensorType;
@@ -50,7 +51,7 @@ public final class EmbracedNode extends CompositeNode {
}
@Override
- public TensorType type(TypeContext context) {
+ public TensorType type(TypeContext<Reference> context) {
return value.type(context);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java
index 477f4db4981..6bb163590de 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.rule;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.TensorType;
@@ -48,7 +49,7 @@ public abstract class ExpressionNode implements Serializable {
* @param context the variable type bindings to use for this evaluation
* @throws IllegalArgumentException if there are variables which are not bound in the given map
*/
- public abstract TensorType type(TypeContext context);
+ public abstract TensorType type(TypeContext<Reference> context);
/**
* Returns the value of evaluating this expression over the given context.
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java
index 79515229019..1da2210a39c 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.rule;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
@@ -67,7 +68,7 @@ public final class FunctionNode extends CompositeNode {
}
@Override
- public TensorType type(TypeContext context) {
+ public TensorType type(TypeContext<Reference> context) {
if (arguments.expressions().size() == 0)
return TensorType.empty;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java
new file mode 100644
index 00000000000..ed1e2838717
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java
@@ -0,0 +1,74 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.rule;
+
+import com.google.common.collect.ImmutableMap;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * The context of a function invocation.
+ *
+ * @author bratseth
+ */
+public class FunctionReferenceContext {
+
+ /** Expression functions indexed by name */
+ private final ImmutableMap<String, ExpressionFunction> functions;
+
+ /** Mapping from argument names to the expressions they resolve to */
+ // TODO: Make private
+ public final Map<String, String> bindings = new HashMap<>();
+
+ /** Create a context for a single serialization task */
+ public FunctionReferenceContext() {
+ this(Collections.emptyList());
+ }
+
+ /** Create a context for a single serialization task */
+ public FunctionReferenceContext(Collection<ExpressionFunction> functions) {
+ this(toMap(functions), Collections.emptyMap());
+ }
+
+ public FunctionReferenceContext(Collection<ExpressionFunction> functions, Map<String, String> bindings) {
+ this(toMap(functions), bindings);
+ }
+
+ /** Create a context for a single serialization task */
+ public FunctionReferenceContext(Map<String, ExpressionFunction> functions) {
+ this(functions.values());
+ }
+
+ /** Create a context for a single serialization task */
+ public FunctionReferenceContext(Map<String, ExpressionFunction> functions, Map<String, String> bindings) {
+ this.functions = ImmutableMap.copyOf(functions);
+ if (bindings != null)
+ this.bindings.putAll(bindings);
+ }
+
+ private static ImmutableMap<String, ExpressionFunction> toMap(Collection<ExpressionFunction> list) {
+ ImmutableMap.Builder<String,ExpressionFunction> mapBuilder = new ImmutableMap.Builder<>();
+ for (ExpressionFunction function : list)
+ mapBuilder.put(function.getName(), function);
+ return mapBuilder.build();
+ }
+
+ /**
+ * Returns a function or null if it isn't defined in this context
+ */
+ public ExpressionFunction getFunction(String name) { return functions.get(name); }
+
+ protected Map<String, ExpressionFunction> functions() { return functions; }
+
+ /** Returns the resolution of an argument, or null if it isn't defined in this context */
+ public String getBinding(String name) { return bindings.get(name); }
+
+ /** Returns a new context with the bindings replaced by the given bindings */
+ public FunctionReferenceContext withBindings(Map<String, String> bindings) {
+ return new FunctionReferenceContext(this.functions, bindings);
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java
index e42884ecc05..c87eb0ace39 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.rule;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
@@ -48,7 +49,7 @@ public class GeneratorLambdaFunctionNode extends CompositeNode {
}
@Override
- public TensorType type(TypeContext context) { return type; }
+ public TensorType type(TypeContext<Reference> context) { return type; }
/** Evaluate this in a context which must have the arguments bound */
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java
index 66b250736e8..ee4edac4941 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.rule;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.TensorType;
@@ -75,7 +76,7 @@ public final class IfNode extends CompositeNode {
}
@Override
- public TensorType type(TypeContext context) {
+ public TensorType type(TypeContext<Reference> context) {
TensorType trueType = trueExpression.type(context);
TensorType falseType = falseExpression.type(context);
return trueType.dimensionwiseGeneralizationWith(falseType).orElseThrow(() ->
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
index da946228291..61086f8182a 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
@@ -2,6 +2,7 @@
package com.yahoo.searchlib.rankingexpression.rule;
import com.google.common.collect.ImmutableList;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
@@ -57,7 +58,7 @@ public class LambdaFunctionNode extends CompositeNode {
}
@Override
- public TensorType type(TypeContext context) {
+ public TensorType type(TypeContext<Reference> context) {
return TensorType.empty; // by definition - no nested lambdas
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java
index f55ed59b65c..f1adf331630 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.rule;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.TensorType;
@@ -14,6 +15,7 @@ import java.util.Deque;
*
* @author Simon Thoresen
*/
+// TODO: This is achieved by ReferenceNode in almost all cases - remove this
public final class NameNode extends ExpressionNode {
private final String name;
@@ -32,7 +34,7 @@ public final class NameNode extends ExpressionNode {
}
@Override
- public TensorType type(TypeContext context) { throw new RuntimeException("Named nodes can not have a type"); }
+ public TensorType type(TypeContext<Reference> context) { throw new RuntimeException("Named nodes can not have a type"); }
@Override
public Value evaluate(Context context) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java
index 9cbe5f98c72..fcc03dc4862 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.rule;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.TensorType;
@@ -38,7 +39,7 @@ public class NegativeNode extends CompositeNode {
}
@Override
- public TensorType type(TypeContext context) {
+ public TensorType type(TypeContext<Reference> context) {
return value.type(context);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java
index e7041600635..a539f496ff5 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.rule;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.TensorType;
@@ -38,7 +39,7 @@ public class NotNode extends BooleanNode {
}
@Override
- public TensorType type(TypeContext context) {
+ public TensorType type(TypeContext<Reference> context) {
return value.type(context);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
index 05a6773c5cb..78f53b1593d 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
@@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression.rule;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.TensorType;
@@ -13,114 +14,102 @@ import java.util.Deque;
import java.util.List;
/**
- * A node referring either to a value in the context or to another named ranking expression.
+ * A node referring either to a value in the context or to a named ranking expression (function aka macro).
*
* @author simon
* @author bratseth
*/
public final class ReferenceNode extends CompositeNode {
- private final String name, output;
-
- private final Arguments arguments;
+ private final Reference reference;
+ /* Creates a node with a simple identifier reference */
public ReferenceNode(String name) {
this(name, null, null);
}
public ReferenceNode(String name, List<? extends ExpressionNode> arguments, String output) {
- this.name = name;
- this.arguments = arguments != null ? new Arguments(arguments) : new Arguments();
- this.output = output;
+ this.reference = new Reference(name,
+ arguments != null ? new Arguments(arguments) : new Arguments(),
+ output);
+ }
+
+ public ReferenceNode(Reference reference) {
+ this.reference = reference;
}
public String getName() {
- return name;
+ return reference.name();
}
/** Returns the arguments, never null */
- public Arguments getArguments() { return arguments; }
+ public Arguments getArguments() { return reference.arguments(); }
/** Returns a copy of this where the arguments are replaced by the given arguments */
public ReferenceNode setArguments(List<ExpressionNode> arguments) {
- return new ReferenceNode(name, arguments, output);
+ return new ReferenceNode(reference.withArguments(new Arguments(arguments)));
}
/** Returns the specific output this references, or null if none specified */
- public String getOutput() { return output; }
+ public String getOutput() { return reference.output(); }
/** Returns a copy of this node with a modified output */
public ReferenceNode setOutput(String output) {
- return new ReferenceNode(name, arguments.expressions(), output);
+ return new ReferenceNode(reference.withOutput(output));
}
/** Returns an empty list as this has no children */
@Override
- public List<ExpressionNode> children() { return arguments.expressions(); }
+ public List<ExpressionNode> children() { return reference.arguments().expressions(); }
@Override
public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
- if (path == null)
- path = new ArrayDeque<>();
- String myName = this.name;
- String myOutput = this.output;
- List<ExpressionNode> myArguments = this.arguments.expressions();
-
- String resolvedArgument = context.getBinding(myName);
- if (resolvedArgument != null && this.arguments.expressions().size() == 0 && myOutput == null) {
- // Replace this whole node with the value of the argument value that it maps to
- myName = resolvedArgument;
- myArguments = null;
- myOutput = null;
- } else if (context.getFunction(myName) != null) {
- // Replace by the referenced expression
- ExpressionFunction function = context.getFunction(myName);
- if (function != null && myArguments != null && function.arguments().size() == myArguments.size() && myOutput == null) {
- String myPath = name + this.arguments.expressions();
- if (path.contains(myPath)) {
- throw new IllegalStateException("Cycle in ranking expression function: " + path);
- }
- path.addLast(myPath);
- ExpressionFunction.Instance instance = function.expand(context, myArguments, path);
- path.removeLast();
- context.addFunctionSerialization(RankingExpression.propertyName(instance.getName()), instance.getExpressionString());
- myName = "rankingExpression(" + instance.getName() + ")";
- myArguments = null;
- myOutput = null;
- }
+ if (reference.isIdentifier() && context.getBinding(getName()) != null) {
+ // a bound identifier: replace by the value it is bound to
+ return context.getBinding(getName());
}
- // Always print the same way, the magic is already done.
- StringBuilder ret = new StringBuilder(myName);
- if (myArguments != null && myArguments.size() > 0) {
- ret.append("(");
- for (int i = 0; i < myArguments.size(); ++i) {
- ret.append(myArguments.get(i).toString(context, path, this));
- if (i < myArguments.size() - 1) {
- ret.append(",");
- }
- }
- ret.append(")");
+
+ ExpressionFunction function = context.getFunction(getName());
+ if (function != null && function.arguments().size() == getArguments().size() && getOutput() == null) {
+ // a function reference: replace by the referenced function wrapped in rankingExpression
+ if (path == null)
+ path = new ArrayDeque<>();
+ String myPath = getName() + getArguments().expressions();
+ if (path.contains(myPath))
+ throw new IllegalStateException("Cycle in ranking expression function: " + path);
+ path.addLast(myPath);
+ ExpressionFunction.Instance instance = function.expand(context, getArguments().expressions(), path);
+ path.removeLast();
+ context.addFunctionSerialization(RankingExpression.propertyName(instance.getName()), instance.getExpressionString());
+ return "rankingExpression(" + instance.getName() + ")";
}
- ret.append(myOutput != null ? "." + myOutput : "");
- return ret.toString();
+
+ // not resolved in this context: output as-is
+ return reference.toString(context, path, parent);
}
+ /** Returns the reference of this node */
+ public Reference reference() { return reference; }
+
@Override
- public TensorType type(TypeContext context) {
- // Don't support outputs of different type, for simplicity
- return context.getType(toString());
+ public TensorType type(TypeContext<Reference> context) {
+ TensorType type = context.getType(reference);
+ if (type == null)
+ throw new IllegalArgumentException("Unknown feature '" + toString() + "'");
+ return type;
}
@Override
public Value evaluate(Context context) {
- if (arguments.expressions().isEmpty() && output == null)
- return context.get(name);
- return context.get(name, arguments, output);
+ // TODO: Context should accept a Reference instead.
+ if (reference.isIdentifier())
+ return context.get(reference.name());
+ return context.get(getName(), getArguments(), getOutput());
}
@Override
public CompositeNode setChildren(List<ExpressionNode> newChildren) {
- return new ReferenceNode(name, newChildren, output);
+ return setArguments(newChildren);
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java
index ba765d07094..796c13a8669 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java
@@ -16,17 +16,11 @@ import java.util.Map;
*
* @author bratseth
*/
-public class SerializationContext {
+public class SerializationContext extends FunctionReferenceContext {
- /** Expression functions indexed by name */
- private final ImmutableMap<String, ExpressionFunction> functions;
-
- /** A cache of already serialized expressions indexed by name */
+ /** Serialized form of functions indexed by name */
private final Map<String, String> serializedFunctions;
- /** Mapping from argument names to the expressions they resolve to */
- public final Map<String, String> bindings = new HashMap<>();
-
/** Create a context for a single serialization task */
public SerializationContext() {
this(Collections.emptyList());
@@ -77,17 +71,10 @@ public class SerializationContext {
*/
public SerializationContext(ImmutableMap<String,ExpressionFunction> functions, Map<String, String> bindings,
Map<String, String> serializedFunctions) {
- this.functions = functions;
+ super(functions, bindings);
this.serializedFunctions = serializedFunctions;
- if (bindings != null)
- this.bindings.putAll(bindings);
}
- /**
- * Returns a function or null if it isn't defined in this context
- */
- public ExpressionFunction getFunction(String name) { return functions.get(name); }
-
/** Adds the serialization of a function */
public void addFunctionSerialization(String name, String expressionString) {
serializedFunctions.put(name, expressionString);
@@ -98,17 +85,9 @@ public class SerializationContext {
return serializedFunctions.get(name);
}
- /**
- * Returns the resolution of an argument, or null if it isn't defined in this context
- */
- public String getBinding(String name) { return bindings.get(name); }
-
- /**
- * Returns a new context which shares the functions and serialized function map with this but has different
- * arguments.
- */
- public SerializationContext createBinding(Map<String, String> arguments) {
- return new SerializationContext(this.functions, arguments, this.serializedFunctions);
+ @Override
+ public SerializationContext withBindings(Map<String, String> bindings) {
+ return new SerializationContext(functions().values(), bindings, this.serializedFunctions);
}
public Map<String, String> serializedFunctions() { return serializedFunctions; }
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java
index a7b82f4753f..cb31219579a 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java
@@ -2,6 +2,7 @@
package com.yahoo.searchlib.rankingexpression.rule;
import com.google.common.collect.ImmutableList;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
@@ -60,7 +61,7 @@ public class SetMembershipNode extends BooleanNode {
}
@Override
- public TensorType type(TypeContext context) {
+ public TensorType type(TypeContext<Reference> context) {
return TensorType.empty;
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index ec6af4bb413..6c9b6bb4a98 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -2,6 +2,7 @@
package com.yahoo.searchlib.rankingexpression.rule;
import com.google.common.annotations.Beta;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
@@ -64,7 +65,7 @@ public class TensorFunctionNode extends CompositeNode {
}
@Override
- public TensorType type(TypeContext context) { return function.type(context); }
+ public TensorType type(TypeContext<Reference> context) { return function.type(context); }
@Override
public Value evaluate(Context context) {
@@ -111,12 +112,13 @@ public class TensorFunctionNode extends CompositeNode {
public PrimitiveTensorFunction toPrimitive() { return this; }
@Override
- public TensorType type(TypeContext context) {
- return expression.type(context);
+ @SuppressWarnings("unchecked") // Generics awkwardness
+ public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
+ return expression.type((TypeContext<Reference>)context);
}
@Override
- public Tensor evaluate(EvaluationContext context) {
+ public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
return expression.evaluate((Context)context).asTensor();
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index e9030cf5852..f2122bb5da9 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -378,8 +378,13 @@ public class EvaluationTestCase {
private static class StructuredTestContext extends MapContext {
@Override
+ public Value get(String feature) {
+ throw new RuntimeException("Called simple get for feature " + feature);
+ }
+
+ @Override
public Value get(String name, Arguments arguments, String output) {
- if (!name.equals("average")) {
+ if ( ! name.equals("average")) {
throw new IllegalArgumentException("Unknown operation '" + name + "'");
}
if (arguments.expressions().size() != 2) {
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeResolutionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeResolutionTestCase.java
index c882c887c8d..a08d510eec4 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeResolutionTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeResolutionTestCase.java
@@ -3,6 +3,7 @@
package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.TypeContext;
@@ -18,12 +19,17 @@ public class TypeResolutionTestCase {
@Test
public void testTypeResolution() {
- TypeMapContext context = new TypeMapContext();
- context.setType("query(x1)", TensorType.fromSpec("tensor(x[])"));
- context.setType("query(x2)", TensorType.fromSpec("tensor(x[10])"));
- context.setType("query(y1)", TensorType.fromSpec("tensor(y[])"));
- context.setType("query(xy1)", TensorType.fromSpec("tensor(x[10],y[])"));
- context.setType("query(xy2)", TensorType.fromSpec("tensor(x[],y[10])"));
+ MapTypeContext context = new MapTypeContext();
+ context.setType(Reference.simple("query", "x1"),
+ TensorType.fromSpec("tensor(x[])"));
+ context.setType(Reference.simple("query", "x2"),
+ TensorType.fromSpec("tensor(x[10])"));
+ context.setType(Reference.simple("query", "y1"),
+ TensorType.fromSpec("tensor(y[])"));
+ context.setType(Reference.simple("query", "xy1"),
+ TensorType.fromSpec("tensor(x[10],y[])"));
+ context.setType(Reference.simple("query", "xy2"),
+ TensorType.fromSpec("tensor(x[],y[10])"));
assertType("tensor(x[])", "query(x1)", context);
assertType("tensor(x[])", "if (1>0, query(x1), query(x2))", context);
@@ -31,7 +37,7 @@ public class TypeResolutionTestCase {
assertIncompatibleType("if (1>0, query(x1), query(y1))", context);
}
- private void assertType(String type, String expression, TypeContext context) {
+ private void assertType(String type, String expression, TypeContext<Reference> context) {
try {
assertEquals(TensorType.fromSpec(type), new RankingExpression(expression).type(context));
}
@@ -40,7 +46,7 @@ public class TypeResolutionTestCase {
}
}
- private void assertIncompatibleType(String expression, TypeContext context) {
+ private void assertIncompatibleType(String expression, TypeContext<Reference> context) {
try {
new RankingExpression(expression).type(context);
fail("Expected type incompatibility exception");
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/rule/ArgumentsTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/rule/ArgumentsTestCase.java
index 867331e99ce..303135888d8 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/rule/ArgumentsTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/rule/ArgumentsTestCase.java
@@ -9,13 +9,13 @@ import java.util.Collections;
import static org.junit.Assert.*;
/**
- * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a>
+ * @author Simon Thoresen
*/
public class ArgumentsTestCase {
@Test
public void requireThatAccessorsWork() {
- Arguments args = new Arguments(null);
+ Arguments args = new Arguments();
assertTrue(args.expressions().isEmpty());
args = new Arguments(Collections.<ExpressionNode>emptyList());
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java
index 3fb94f1251b..8a969180113 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java
@@ -10,7 +10,7 @@ import com.yahoo.tensor.Tensor;
* @author bratseth
*/
@Beta
-public interface EvaluationContext extends TypeContext {
+public interface EvaluationContext<NAMETYPE extends TypeContext.Name> extends TypeContext<NAMETYPE> {
/** Returns the tensor bound to this name, or null if none */
Tensor getTensor(String name);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java
index 9fe6b7d053f..b9394da31e3 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java
@@ -11,17 +11,20 @@ import java.util.HashMap;
* @author bratseth
*/
@Beta
-public class MapEvaluationContext implements EvaluationContext {
+public class MapEvaluationContext implements EvaluationContext<TypeContext.Name> {
private final java.util.Map<String, Tensor> bindings = new HashMap<>();
- static MapEvaluationContext empty() { return new MapEvaluationContext(); }
-
public void put(String name, Tensor tensor) { bindings.put(name, tensor); }
@Override
public TensorType getType(String name) {
- Tensor tensor = bindings.get(name);
+ return getType(new Name(name));
+ }
+
+ @Override
+ public TensorType getType(Name name) {
+ Tensor tensor = bindings.get(name.toString());
if (tensor == null) return null;
return tensor.type();
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java
index 760a225efdf..ff2e6318b37 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java
@@ -8,7 +8,7 @@ import com.yahoo.tensor.TensorType;
*
* @author bratseth
*/
-public interface TypeContext {
+public interface TypeContext<NAMETYPE extends TypeContext.Name> {
/**
* Returns the type of the tensor with this name.
@@ -16,6 +16,39 @@ public interface TypeContext {
* @return returns the type of the tensor which will be returned by calling getTensor(name)
* or null if getTensor will return null.
*/
+ TensorType getType(NAMETYPE name);
+
+ /**
+ * Returns the type of the tensor with this name by converting from a string name.
+ *
+ * @return returns the type of the tensor which will be returned by calling getTensor(name)
+ * or null if getTensor will return null.
+ */
TensorType getType(String name);
+ /** A name which is just a string. Names are value objects. */
+ class Name {
+
+ private final String name;
+
+ public Name(String name) {
+ this.name = name;
+ }
+
+ @Override
+ public String toString() { return name; }
+
+ @Override
+ public int hashCode() { return name.hashCode(); }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other == this) return true;
+ if ( ! (other instanceof Name)) return false;
+ return ((Name)other).name.equals(this.name);
+ }
+
+ }
+
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
index 34beb465d4c..acb2363cba4 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
@@ -44,7 +44,7 @@ public class VariableTensor extends PrimitiveTensorFunction {
public PrimitiveTensorFunction toPrimitive() { return this; }
@Override
- public TensorType type(TypeContext context) {
+ public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
TensorType givenType = context.getType(name);
if (givenType == null) return null;
verifyType(givenType);
@@ -52,7 +52,7 @@ public class VariableTensor extends PrimitiveTensorFunction {
}
@Override
- public Tensor evaluate(EvaluationContext context) {
+ public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor tensor = context.getTensor(name);
if (tensor == null) return null;
verifyType(tensor.type());
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
index 2109b730e1a..bfc0938abcc 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
@@ -18,10 +18,14 @@ public abstract class CompositeTensorFunction extends TensorFunction {
/** Finds the type this produces by first converting it to a primitive function */
@Override
- public final TensorType type(TypeContext context) { return toPrimitive().type(context); }
+ public final <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
+ return toPrimitive().type(context);
+ }
/** Evaluates this by first converting it to a primitive function */
@Override
- public final Tensor evaluate(EvaluationContext context) { return toPrimitive().evaluate(context); }
+ public final <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ return toPrimitive().evaluate(context);
+ }
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index c77ed1c0526..a073053bec8 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -60,7 +60,7 @@ public class Concat extends PrimitiveTensorFunction {
}
@Override
- public TensorType type(TypeContext context) {
+ public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
return type(argumentA.type(context), argumentB.type(context));
}
@@ -74,7 +74,7 @@ public class Concat extends PrimitiveTensorFunction {
}
@Override
- public Tensor evaluate(EvaluationContext context) {
+ public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
a = ensureIndexedDimension(dimension, a);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
index 50b479da168..a43de297b9a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
@@ -42,10 +42,10 @@ public class ConstantTensor extends PrimitiveTensorFunction {
public PrimitiveTensorFunction toPrimitive() { return this; }
@Override
- public TensorType type(TypeContext context) { return constant.type(); }
+ public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return constant.type(); }
@Override
- public Tensor evaluate(EvaluationContext context) { return constant; }
+ public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { return constant; }
@Override
public String toString(ToStringContext context) { return constant.toString(); }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
index e70d1de3db7..edfa8253eb9 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
@@ -61,10 +61,10 @@ public class Generate extends PrimitiveTensorFunction {
public PrimitiveTensorFunction toPrimitive() { return this; }
@Override
- public TensorType type(TypeContext context) { return type; }
+ public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return type; }
@Override
- public Tensor evaluate(EvaluationContext context) {
+ public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor.Builder builder = Tensor.Builder.of(type);
IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(dimensionSizes(type));
for (int i = 0; i < indexes.size(); i++) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 7812c985091..17e1c103ea3 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -95,12 +95,12 @@ public class Join extends PrimitiveTensorFunction {
}
@Override
- public TensorType type(TypeContext context) {
+ public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
return new TensorType.Builder(argumentA.type(context), argumentB.type(context)).build();
}
@Override
- public Tensor evaluate(EvaluationContext context) {
+ public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build();
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
index 53504868ff2..4a338e5501e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
@@ -53,12 +53,12 @@ public class Map extends PrimitiveTensorFunction {
}
@Override
- public TensorType type(TypeContext context) {
+ public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
return argument.type(context);
}
@Override
- public Tensor evaluate(EvaluationContext context) {
+ public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor argument = argument().evaluate(context);
Tensor.Builder builder = Tensor.Builder.of(argument.type());
for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index 76a938b9fe2..e045effbe7e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -101,11 +101,12 @@ public class Reduce extends PrimitiveTensorFunction {
}
@Override
- public TensorType type(TypeContext context) {
+ public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
return type(argument.type(context));
}
private TensorType type(TensorType argumentType) {
+ if (dimensions.isEmpty()) return TensorType.empty; // means reduce all
TensorType.Builder builder = new TensorType.Builder();
for (TensorType.Dimension dimension : argumentType.dimensions())
if ( ! dimensions.contains(dimension.name())) // keep
@@ -114,7 +115,7 @@ public class Reduce extends PrimitiveTensorFunction {
}
@Override
- public Tensor evaluate(EvaluationContext context) {
+ public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor argument = this.argument.evaluate(context);
if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions))
throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " +
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
index de3d2be265a..af4492ca1e4 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
@@ -72,7 +72,7 @@ public class Rename extends PrimitiveTensorFunction {
public PrimitiveTensorFunction toPrimitive() { return this; }
@Override
- public TensorType type(TypeContext context) {
+ public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
return type(argument.type(context));
}
@@ -84,7 +84,7 @@ public class Rename extends PrimitiveTensorFunction {
}
@Override
- public Tensor evaluate(EvaluationContext context) {
+ public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor tensor = argument.evaluate(context);
TensorType renamedType = type(tensor.type());
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
index 78ab09c7820..e805e9d87bb 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
@@ -43,14 +43,14 @@ public abstract class TensorFunction {
*
* @param context a context which must be passed to all nexted functions when evaluating
*/
- public abstract Tensor evaluate(EvaluationContext context);
+ public abstract <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context);
/**
* Returns the type of the tensor this produces given the input types in the context
*
* @param context a context which must be passed to all nexted functions when evaluating
*/
- public abstract TensorType type(TypeContext context);
+ public abstract <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context);
/** Evaluate with no context */
public final Tensor evaluate() { return evaluate(new MapEvaluationContext()); }
@@ -58,7 +58,7 @@ public abstract class TensorFunction {
/**
* Return a string representation of this context.
*
- * @param context a context which must be passed to all nexted functions when requesting the string value
+ * @param context a context which must be passed to all nested functions when requesting the string value
*/
public abstract String toString(ToStringContext context);