summaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java11
1 files changed, 6 insertions, 5 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java
index b8ca114343d..902144cfea2 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java
@@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
@@ -56,11 +57,11 @@ public class Sum extends IntermediateOperation {
// optimization: if keepDims and one reduce dimension that has size 1: same as identity.
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
if ( ! allInputTypesPresent(2)) return null;
- TensorFunction inputFunction = inputs.get(0).function().get();
- TensorFunction output = new Reduce(inputFunction, Reduce.Aggregator.sum, reduceDimensions);
+ TensorFunction<Reference> inputFunction = inputs.get(0).function().get();
+ TensorFunction<Reference> output = new Reduce<>(inputFunction, Reduce.Aggregator.sum, reduceDimensions);
if (shouldKeepDimensions()) {
// multiply with a generated tensor created from the reduced dimensions
TensorType.Builder typeBuilder = new TensorType.Builder(resultValueType());
@@ -69,9 +70,9 @@ public class Sum extends IntermediateOperation {
}
TensorType generatedType = typeBuilder.build();
ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1));
- Generate generatedFunction = new Generate(generatedType,
+ Generate<Reference> generatedFunction = new Generate<>(generatedType,
new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator());
- output = new com.yahoo.tensor.functions.Join(output, generatedFunction, ScalarFunctions.multiply());
+ output = new com.yahoo.tensor.functions.Join<>(output, generatedFunction, ScalarFunctions.multiply());
}
return output;
}