aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/main/java/com/yahoo
diff options
context:
space:
mode:
Diffstat (limited to 'searchlib/src/main/java/com/yahoo')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java8
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java29
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java2
3 files changed, 20 insertions, 19 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java
index 63cea371d14..41b01c9a2cb 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java
@@ -2,6 +2,7 @@
package com.yahoo.searchlib.rankingexpression.evaluation.tensoroptimization;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
import com.yahoo.searchlib.rankingexpression.evaluation.OptimizationReport;
import com.yahoo.searchlib.rankingexpression.evaluation.Optimizer;
@@ -58,11 +59,12 @@ public class TensorOptimizer extends Optimizer {
* The ReduceJoin class determines whether or not the arguments are
* compatible with the optimization.
*/
+ @SuppressWarnings("unchecked")
private ExpressionNode optimizeReduceJoin(ExpressionNode node) {
if ( ! (node instanceof TensorFunctionNode)) {
return node;
}
- TensorFunction function = ((TensorFunctionNode) node).function();
+ TensorFunction<Reference> function = ((TensorFunctionNode) node).function();
if ( ! (function instanceof Reduce)) {
return node;
}
@@ -74,10 +76,10 @@ public class TensorOptimizer extends Optimizer {
if ( ! (child instanceof TensorFunctionNode)) {
return node;
}
- TensorFunction argument = ((TensorFunctionNode) child).function();
+ TensorFunction<Reference> argument = ((TensorFunctionNode) child).function();
if (argument instanceof Join) {
report.incMetric("Replaced reduce->join", 1);
- return new TensorFunctionNode(new ReduceJoin((Reduce)function, (Join)argument));
+ return new TensorFunctionNode(new ReduceJoin<>((Reduce<Reference>)function, (Join<Reference>)argument));
}
return node;
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index 18f1fa8a78f..cec8837abcd 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -33,14 +33,14 @@ import java.util.stream.Collectors;
@Beta
public class TensorFunctionNode extends CompositeNode {
- private final TensorFunction function;
+ private final TensorFunction<Reference> function;
- public TensorFunctionNode(TensorFunction function) {
+ public TensorFunctionNode(TensorFunction<Reference> function) {
this.function = function;
}
/** Returns the tensor function wrapped by this */
- public TensorFunction function() { return function; }
+ public TensorFunction<Reference> function() { return function; }
@Override
public List<ExpressionNode> children() {
@@ -49,7 +49,7 @@ public class TensorFunctionNode extends CompositeNode {
.collect(Collectors.toList());
}
- private ExpressionNode toExpressionNode(TensorFunction f) {
+ private ExpressionNode toExpressionNode(TensorFunction<Reference> f) {
if (f instanceof ExpressionTensorFunction)
return ((ExpressionTensorFunction)f).expression;
else
@@ -58,9 +58,9 @@ public class TensorFunctionNode extends CompositeNode {
@Override
public CompositeNode setChildren(List<ExpressionNode> children) {
- List<TensorFunction> wrappedChildren = children.stream()
- .map(ExpressionTensorFunction::new)
- .collect(Collectors.toList());
+ List<TensorFunction<Reference>> wrappedChildren = children.stream()
+ .map(ExpressionTensorFunction::new)
+ .collect(Collectors.toList());
return new TensorFunctionNode(function.withArguments(wrappedChildren));
}
@@ -132,7 +132,7 @@ public class TensorFunctionNode extends CompositeNode {
* A tensor function implemented by an expression.
* This allows us to pass expressions as tensor function arguments.
*/
- public static class ExpressionTensorFunction extends PrimitiveTensorFunction {
+ public static class ExpressionTensorFunction extends PrimitiveTensorFunction<Reference> {
/** An expression which produces a tensor */
private final ExpressionNode expression;
@@ -142,7 +142,7 @@ public class TensorFunctionNode extends CompositeNode {
}
@Override
- public List<TensorFunction> arguments() {
+ public List<TensorFunction<Reference>> arguments() {
if (expression instanceof CompositeNode)
return ((CompositeNode)expression).children().stream()
.map(ExpressionTensorFunction::new)
@@ -152,7 +152,7 @@ public class TensorFunctionNode extends CompositeNode {
}
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<Reference> withArguments(List<TensorFunction<Reference>> arguments) {
if (arguments.size() == 0) return this;
List<ExpressionNode> unwrappedChildren = arguments.stream()
.map(arg -> ((ExpressionTensorFunction)arg).expression)
@@ -161,16 +161,15 @@ public class TensorFunctionNode extends CompositeNode {
}
@Override
- public PrimitiveTensorFunction toPrimitive() { return this; }
+ public PrimitiveTensorFunction<Reference> toPrimitive() { return this; }
@Override
- @SuppressWarnings("unchecked") // Generics awkwardness
- public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
- return expression.type((TypeContext<Reference>)context);
+ public TensorType type(TypeContext<Reference> context) {
+ return expression.type(context);
}
@Override
- public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ public Tensor evaluate(EvaluationContext<Reference> context) {
return expression.evaluate((Context)context).asTensor();
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java
index 9a38b5efc1f..9bed4a4ea7c 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java
@@ -87,7 +87,7 @@ public class TensorMaxMinTransformer<CONTEXT extends TransformContext> extends E
Reduce.Aggregator aggregator = Reduce.Aggregator.valueOf(node.getFunction().name());
String dimension = ((ReferenceNode) arg2).getName();
- return new TensorFunctionNode(new Reduce(expression, aggregator, dimension));
+ return new TensorFunctionNode(new Reduce<>(expression, aggregator, dimension));
}
}