diff options
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java | 11 |
1 files changed, 6 insertions, 5 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java index b8ca114343d..902144cfea2 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; @@ -56,11 +57,11 @@ public class Sum extends IntermediateOperation { // optimization: if keepDims and one reduce dimension that has size 1: same as identity. @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! allInputTypesPresent(2)) return null; - TensorFunction inputFunction = inputs.get(0).function().get(); - TensorFunction output = new Reduce(inputFunction, Reduce.Aggregator.sum, reduceDimensions); + TensorFunction<Reference> inputFunction = inputs.get(0).function().get(); + TensorFunction<Reference> output = new Reduce<>(inputFunction, Reduce.Aggregator.sum, reduceDimensions); if (shouldKeepDimensions()) { // multiply with a generated tensor created from the reduced dimensions TensorType.Builder typeBuilder = new TensorType.Builder(resultValueType()); @@ -69,9 +70,9 @@ public class Sum extends IntermediateOperation { } TensorType generatedType = typeBuilder.build(); ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1)); - Generate generatedFunction = new Generate(generatedType, + Generate<Reference> generatedFunction = new Generate<>(generatedType, new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator()); - output = new com.yahoo.tensor.functions.Join(output, generatedFunction, ScalarFunctions.multiply()); + output = new com.yahoo.tensor.functions.Join<>(output, generatedFunction, ScalarFunctions.multiply()); } return output; } |