summaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java15
1 files changed, 8 insertions, 7 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java
index 8e49ce15265..b7a8a4a4e43 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java
@@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
@@ -72,14 +73,14 @@ public class Reduce extends IntermediateOperation {
}
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
if ( ! allInputTypesPresent(1)) return null;
- TensorFunction inputFunction = inputs.get(0).function().get();
+ TensorFunction<Reference> inputFunction = inputs.get(0).function().get();
if (preOperator != null) {
- inputFunction = new com.yahoo.tensor.functions.Map(inputFunction, preOperator);
+ inputFunction = new com.yahoo.tensor.functions.Map<>(inputFunction, preOperator);
}
- TensorFunction output = new com.yahoo.tensor.functions.Reduce(inputFunction, aggregator, reduceDimensions);
+ TensorFunction<Reference> output = new com.yahoo.tensor.functions.Reduce<>(inputFunction, aggregator, reduceDimensions);
if (shouldKeepDimensions()) {
// multiply with a generated tensor created from the reduced dimensions
TensorType.Builder typeBuilder = new TensorType.Builder(resultValueType());
@@ -88,12 +89,12 @@ public class Reduce extends IntermediateOperation {
}
TensorType generatedType = typeBuilder.build();
ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1));
- Generate generatedFunction = new Generate(generatedType,
+ Generate<Reference> generatedFunction = new Generate<>(generatedType,
new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator());
- output = new com.yahoo.tensor.functions.Join(output, generatedFunction, ScalarFunctions.multiply());
+ output = new com.yahoo.tensor.functions.Join<>(output, generatedFunction, ScalarFunctions.multiply());
}
if (postOperator != null) {
- output = new com.yahoo.tensor.functions.Map(output, postOperator);
+ output = new com.yahoo.tensor.functions.Map<>(output, postOperator);
}
return output;
}