diff options
Diffstat (limited to 'searchlib/src/main/java/com/yahoo')
3 files changed, 20 insertions, 19 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java index 63cea371d14..41b01c9a2cb 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java @@ -2,6 +2,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation.tensoroptimization; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; import com.yahoo.searchlib.rankingexpression.evaluation.OptimizationReport; import com.yahoo.searchlib.rankingexpression.evaluation.Optimizer; @@ -58,11 +59,12 @@ public class TensorOptimizer extends Optimizer { * The ReduceJoin class determines whether or not the arguments are * compatible with the optimization. */ + @SuppressWarnings("unchecked") private ExpressionNode optimizeReduceJoin(ExpressionNode node) { if ( ! (node instanceof TensorFunctionNode)) { return node; } - TensorFunction function = ((TensorFunctionNode) node).function(); + TensorFunction<Reference> function = ((TensorFunctionNode) node).function(); if ( ! (function instanceof Reduce)) { return node; } @@ -74,10 +76,10 @@ public class TensorOptimizer extends Optimizer { if ( ! (child instanceof TensorFunctionNode)) { return node; } - TensorFunction argument = ((TensorFunctionNode) child).function(); + TensorFunction<Reference> argument = ((TensorFunctionNode) child).function(); if (argument instanceof Join) { report.incMetric("Replaced reduce->join", 1); - return new TensorFunctionNode(new ReduceJoin((Reduce)function, (Join)argument)); + return new TensorFunctionNode(new ReduceJoin<>((Reduce<Reference>)function, (Join<Reference>)argument)); } return node; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index 18f1fa8a78f..cec8837abcd 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -33,14 +33,14 @@ import java.util.stream.Collectors; @Beta public class TensorFunctionNode extends CompositeNode { - private final TensorFunction function; + private final TensorFunction<Reference> function; - public TensorFunctionNode(TensorFunction function) { + public TensorFunctionNode(TensorFunction<Reference> function) { this.function = function; } /** Returns the tensor function wrapped by this */ - public TensorFunction function() { return function; } + public TensorFunction<Reference> function() { return function; } @Override public List<ExpressionNode> children() { @@ -49,7 +49,7 @@ public class TensorFunctionNode extends CompositeNode { .collect(Collectors.toList()); } - private ExpressionNode toExpressionNode(TensorFunction f) { + private ExpressionNode toExpressionNode(TensorFunction<Reference> f) { if (f instanceof ExpressionTensorFunction) return ((ExpressionTensorFunction)f).expression; else @@ -58,9 +58,9 @@ public class TensorFunctionNode extends CompositeNode { @Override public CompositeNode setChildren(List<ExpressionNode> children) { - List<TensorFunction> wrappedChildren = children.stream() - .map(ExpressionTensorFunction::new) - .collect(Collectors.toList()); + List<TensorFunction<Reference>> wrappedChildren = children.stream() + .map(ExpressionTensorFunction::new) + .collect(Collectors.toList()); return new TensorFunctionNode(function.withArguments(wrappedChildren)); } @@ -132,7 +132,7 @@ public class TensorFunctionNode extends CompositeNode { * A tensor function implemented by an expression. * This allows us to pass expressions as tensor function arguments. */ - public static class ExpressionTensorFunction extends PrimitiveTensorFunction { + public static class ExpressionTensorFunction extends PrimitiveTensorFunction<Reference> { /** An expression which produces a tensor */ private final ExpressionNode expression; @@ -142,7 +142,7 @@ public class TensorFunctionNode extends CompositeNode { } @Override - public List<TensorFunction> arguments() { + public List<TensorFunction<Reference>> arguments() { if (expression instanceof CompositeNode) return ((CompositeNode)expression).children().stream() .map(ExpressionTensorFunction::new) @@ -152,7 +152,7 @@ public class TensorFunctionNode extends CompositeNode { } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<Reference> withArguments(List<TensorFunction<Reference>> arguments) { if (arguments.size() == 0) return this; List<ExpressionNode> unwrappedChildren = arguments.stream() .map(arg -> ((ExpressionTensorFunction)arg).expression) @@ -161,16 +161,15 @@ public class TensorFunctionNode extends CompositeNode { } @Override - public PrimitiveTensorFunction toPrimitive() { return this; } + public PrimitiveTensorFunction<Reference> toPrimitive() { return this; } @Override - @SuppressWarnings("unchecked") // Generics awkwardness - public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { - return expression.type((TypeContext<Reference>)context); + public TensorType type(TypeContext<Reference> context) { + return expression.type(context); } @Override - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext<Reference> context) { return expression.evaluate((Context)context).asTensor(); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java index 9a38b5efc1f..9bed4a4ea7c 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java @@ -87,7 +87,7 @@ public class TensorMaxMinTransformer<CONTEXT extends TransformContext> extends E Reduce.Aggregator aggregator = Reduce.Aggregator.valueOf(node.getFunction().name()); String dimension = ((ReferenceNode) arg2).getName(); - return new TensorFunctionNode(new Reduce(expression, aggregator, dimension)); + return new TensorFunctionNode(new Reduce<>(expression, aggregator, dimension)); } } |