diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-11-27 15:58:06 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-11-27 15:58:06 +0200 |
commit | 77bb8f5117b7a0f78b2dc99a3937430339e4291d (patch) | |
tree | 9037b54f17e3175a8d11e1b43b55b71887f867a4 /searchlib/src/main | |
parent | f4203c3cc571722f08ee65047437c1290ed63f69 (diff) |
Support index generating expressions in tensor value functions
Diffstat (limited to 'searchlib/src/main')
4 files changed, 45 insertions, 28 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java index 63cea371d14..41b01c9a2cb 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java @@ -2,6 +2,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation.tensoroptimization; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; import com.yahoo.searchlib.rankingexpression.evaluation.OptimizationReport; import com.yahoo.searchlib.rankingexpression.evaluation.Optimizer; @@ -58,11 +59,12 @@ public class TensorOptimizer extends Optimizer { * The ReduceJoin class determines whether or not the arguments are * compatible with the optimization. */ + @SuppressWarnings("unchecked") private ExpressionNode optimizeReduceJoin(ExpressionNode node) { if ( ! (node instanceof TensorFunctionNode)) { return node; } - TensorFunction function = ((TensorFunctionNode) node).function(); + TensorFunction<Reference> function = ((TensorFunctionNode) node).function(); if ( ! (function instanceof Reduce)) { return node; } @@ -74,10 +76,10 @@ public class TensorOptimizer extends Optimizer { if ( ! (child instanceof TensorFunctionNode)) { return node; } - TensorFunction argument = ((TensorFunctionNode) child).function(); + TensorFunction<Reference> argument = ((TensorFunctionNode) child).function(); if (argument instanceof Join) { report.incMetric("Replaced reduce->join", 1); - return new TensorFunctionNode(new ReduceJoin((Reduce)function, (Join)argument)); + return new TensorFunctionNode(new ReduceJoin<>((Reduce<Reference>)function, (Join<Reference>)argument)); } return node; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index 18f1fa8a78f..cec8837abcd 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -33,14 +33,14 @@ import java.util.stream.Collectors; @Beta public class TensorFunctionNode extends CompositeNode { - private final TensorFunction function; + private final TensorFunction<Reference> function; - public TensorFunctionNode(TensorFunction function) { + public TensorFunctionNode(TensorFunction<Reference> function) { this.function = function; } /** Returns the tensor function wrapped by this */ - public TensorFunction function() { return function; } + public TensorFunction<Reference> function() { return function; } @Override public List<ExpressionNode> children() { @@ -49,7 +49,7 @@ public class TensorFunctionNode extends CompositeNode { .collect(Collectors.toList()); } - private ExpressionNode toExpressionNode(TensorFunction f) { + private ExpressionNode toExpressionNode(TensorFunction<Reference> f) { if (f instanceof ExpressionTensorFunction) return ((ExpressionTensorFunction)f).expression; else @@ -58,9 +58,9 @@ public class TensorFunctionNode extends CompositeNode { @Override public CompositeNode setChildren(List<ExpressionNode> children) { - List<TensorFunction> wrappedChildren = children.stream() - .map(ExpressionTensorFunction::new) - .collect(Collectors.toList()); + List<TensorFunction<Reference>> wrappedChildren = children.stream() + .map(ExpressionTensorFunction::new) + .collect(Collectors.toList()); return new TensorFunctionNode(function.withArguments(wrappedChildren)); } @@ -132,7 +132,7 @@ public class TensorFunctionNode extends CompositeNode { * A tensor function implemented by an expression. * This allows us to pass expressions as tensor function arguments. */ - public static class ExpressionTensorFunction extends PrimitiveTensorFunction { + public static class ExpressionTensorFunction extends PrimitiveTensorFunction<Reference> { /** An expression which produces a tensor */ private final ExpressionNode expression; @@ -142,7 +142,7 @@ public class TensorFunctionNode extends CompositeNode { } @Override - public List<TensorFunction> arguments() { + public List<TensorFunction<Reference>> arguments() { if (expression instanceof CompositeNode) return ((CompositeNode)expression).children().stream() .map(ExpressionTensorFunction::new) @@ -152,7 +152,7 @@ public class TensorFunctionNode extends CompositeNode { } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<Reference> withArguments(List<TensorFunction<Reference>> arguments) { if (arguments.size() == 0) return this; List<ExpressionNode> unwrappedChildren = arguments.stream() .map(arg -> ((ExpressionTensorFunction)arg).expression) @@ -161,16 +161,15 @@ public class TensorFunctionNode extends CompositeNode { } @Override - public PrimitiveTensorFunction toPrimitive() { return this; } + public PrimitiveTensorFunction<Reference> toPrimitive() { return this; } @Override - @SuppressWarnings("unchecked") // Generics awkwardness - public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { - return expression.type((TypeContext<Reference>)context); + public TensorType type(TypeContext<Reference> context) { + return expression.type(context); } @Override - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext<Reference> context) { return expression.evaluate((Context)context).asTensor(); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java index 9a38b5efc1f..9bed4a4ea7c 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java @@ -87,7 +87,7 @@ public class TensorMaxMinTransformer<CONTEXT extends TransformContext> extends E Reduce.Aggregator aggregator = Reduce.Aggregator.valueOf(node.getFunction().name()); String dimension = ((ReferenceNode) arg2).getName(); - return new TensorFunctionNode(new Reduce(expression, aggregator, dimension)); + return new TensorFunctionNode(new Reduce<>(expression, aggregator, dimension)); } } diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index beab722a1eb..c7870182939 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -887,30 +887,46 @@ void labelAndDimension(TensorAddress.Builder addressBuilder) : void labelAndDimensionValues(List addressValues) : { - String dimension, label; + String dimension; + Value.DimensionValue dimensionValue; } { - dimension = identifier() <COLON> label = tag() - { addressValues.add(new Value.DimensionValue(dimension, label)); } + dimension = identifier() <COLON> dimensionValue = dimensionValue(Optional.of(dimension)) + { addressValues.add(dimensionValue); } } /** A tensor address (possibly on short form) represented as a list because the tensor type is not available */ List valueAddress() : { List dimensionValues = new ArrayList(); - String label; + ExpressionNode valueExpression; + Value.DimensionValue dimensionValue; } { ( - ( <LSQUARE> ( <INTEGER> { dimensionValues.add(new Value.DimensionValue(token.image)); } ) <RSQUARE> ) - | - LOOKAHEAD(3) ( <LCURLY> label = tag() { dimensionValues.add(new Value.DimensionValue(label)); } <RCURLY> ) + ( <LSQUARE> ( valueExpression = expression() { dimensionValues.add(new Value.DimensionValue(TensorFunctionNode.wrapScalar(valueExpression))); } ) <RSQUARE> ) | - ( <LCURLY> - ( labelAndDimensionValues(dimensionValues))* + LOOKAHEAD(3) ( <LCURLY> + ( labelAndDimensionValues(dimensionValues))+ ( <COMMA> labelAndDimensionValues(dimensionValues))* <RCURLY> ) + | + ( <LCURLY> dimensionValue = dimensionValue(Optional.empty()) { dimensionValues.add(dimensionValue); } <RCURLY> ) ) { return dimensionValues;} +} + +Value.DimensionValue dimensionValue(Optional dimensionName) : +{ + ExpressionNode value; +} +{ + value = expression() + { + if (value instanceof ReferenceNode && ((ReferenceNode)value).reference().isIdentifier()) + return new Value.DimensionValue(dimensionName, ((ReferenceNode)value).reference().name()); + else + return new Value.DimensionValue(dimensionName, TensorFunctionNode.wrapScalar(value)); + } }
\ No newline at end of file |