diff options
author | Jon Bratseth <bratseth@oath.com> | 2019-06-12 18:49:11 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-06-12 18:49:11 +0200 |
commit | cee1c3a3804d5d3c25407b3c4ac64228e9d194e3 (patch) | |
tree | dd62faee06cd29d5820f9bc33a488be55a6ceef8 /searchlib/src/main | |
parent | 5311e389929c05707856697e73db61b6acee3c5a (diff) |
Require constant() for large constants and fix a type resolving bug (#9769)
* Require constant() for large constants and fix a type resolving bug
* Remove noise
* Remove noise
* Remove noise
* Remove noise
* Remove noise
* Remove noise
* Remove noise
* Remove noise
* Remove noise
* Remove noise
* Remove noise
* Remove noise
* Remove noise
* Remove noise
* Remove noise
* Remove noise
* Remove noise
* Remove noise
* Remove noise
* Remove noise
* Remove noise
* Remove noise
* Remove noise
Diffstat (limited to 'searchlib/src/main')
6 files changed, 121 insertions, 4 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java index c4f3a75f2f8..2aedec2109b 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java @@ -5,6 +5,7 @@ import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.transform.TensorMaxMinTransformer; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.TypeContext; import com.yahoo.tensor.functions.Join; @@ -67,6 +68,11 @@ public final class FunctionNode extends CompositeNode { @Override public TensorType type(TypeContext<Reference> context) { + // Check if this node should be interpreted as tensor reduce, as this impacts the type + ExpressionNode thisTransformed = TensorMaxMinTransformer.transformFunctionNode(this, context); + if (thisTransformed != this) + return thisTransformed.type(context); + if (arguments.expressions().size() == 0) return TensorType.empty; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java index 28dc623be72..92c6d6f8638 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java @@ -85,7 +85,9 @@ public final class IfNode extends CompositeNode { return trueType.dimensionwiseGeneralizationWith(falseType).orElseThrow(() -> new IllegalArgumentException("An if expression must produce compatible types in both " + "alternatives, but the 'true' type is " + trueType + " while the " + - "'false' type is " + falseType) + "'false' type is " + falseType + + "\n'true' branch: " + trueExpression + + "\n'false' branch: " + falseExpression) ); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java index eb8d2229a6d..e15ce158e83 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java @@ -95,7 +95,13 @@ public final class ReferenceNode extends CompositeNode { @Override public TensorType type(TypeContext<Reference> context) { - TensorType type = context.getType(reference); + TensorType type = null; + try { + type = context.getType(reference); + } + catch (IllegalArgumentException e) { + throw new IllegalArgumentException(reference + " is invalid", e); + } if (type == null) throw new IllegalArgumentException("Unknown feature '" + toString() + "'"); return type; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java index 22d314bcb28..31567ba120b 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java @@ -10,7 +10,7 @@ import java.util.List; /** * Superclass of expression transformers. The scope (lifetime) of a transformer instance is a single compilation - * of alle the expressions in one rank profile. + * of all the expressions in one rank profile. * * @author bratseth */ diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java new file mode 100644 index 00000000000..979c5b0f88c --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java @@ -0,0 +1,93 @@ +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.transform; + +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.FunctionNode; +import com.yahoo.searchlib.rankingexpression.rule.NameNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; +import com.yahoo.tensor.functions.Reduce; + +import java.util.Optional; + +/** + * Transforms min(tensor,dim) and max(tensor,dim) to + * reduce(tensor,min/max,dim). This is necessary as the backend does + * not recognize these forms of min and max. + * + * @author lesters + */ +public class TensorMaxMinTransformer<CONTEXT extends TransformContext> extends ExpressionTransformer<CONTEXT> { + + @Override + public ExpressionNode transform(ExpressionNode node, CONTEXT context) { + if (node instanceof CompositeNode) { + node = transformChildren((CompositeNode) node, context); + } + if (node instanceof FunctionNode) { + node = transformFunctionNode((FunctionNode) node, context.types()); + } + return node; + } + + public static ExpressionNode transformFunctionNode(FunctionNode node, TypeContext<Reference> context) { + switch (node.getFunction()) { + case min: + case max: + return transformMaxAndMinFunctionNode(node, context); + } + return node; + } + + /** + * Transforms max and min functions if the first + * argument returns a tensor type and the second argument is a valid + * dimension in the tensor. + */ + private static ExpressionNode transformMaxAndMinFunctionNode(FunctionNode node, TypeContext<Reference> context) { + if (node.children().size() != 2) { + return node; + } + ExpressionNode arg1 = node.children().get(0); + Optional<String> dimension = dimensionName(node.children().get(1)); + if (dimension.isPresent()) { + TensorType type = arg1.type(context); + if (type.dimension(dimension.get()).isPresent()) { + return replaceMaxAndMinFunction(node); + } + } + return node; + } + + private static Optional<String> dimensionName(ExpressionNode node) { + if (node instanceof ReferenceNode) { + Reference reference = ((ReferenceNode)node).reference(); + if (reference.isIdentifier()) + return Optional.of(reference.name()); + else + return Optional.empty(); + } + else if (node instanceof NameNode) { + return Optional.of(((NameNode)node).getValue()); + } + else { + return Optional.empty(); + } + } + + private static ExpressionNode replaceMaxAndMinFunction(FunctionNode node) { + ExpressionNode arg1 = node.children().get(0); + ExpressionNode arg2 = node.children().get(1); + + TensorFunctionNode.TensorFunctionExpressionNode expression = TensorFunctionNode.wrapArgument(arg1); + Reduce.Aggregator aggregator = Reduce.Aggregator.valueOf(node.getFunction().name()); + String dimension = ((ReferenceNode) arg2).getName(); + + return new TensorFunctionNode(new Reduce(expression, aggregator, dimension)); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TransformContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TransformContext.java index 7485ce69f98..0113a650277 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TransformContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TransformContext.java @@ -1,7 +1,9 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.transform; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Map; @@ -13,11 +15,19 @@ import java.util.Map; public class TransformContext { private final Map<String, Value> constants; + private final TypeContext<Reference> types; - public TransformContext(Map<String, Value> constants) { + public TransformContext(Map<String, Value> constants, TypeContext<Reference> types) { this.constants = constants; + this.types = types; } public Map<String, Value> constants() { return constants; } + /** + * Returns the types known in this context. We may have type information for references + * for which no value is available + */ + public TypeContext<Reference> types() { return types; } + } |