diff options
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java | 15 |
1 files changed, 8 insertions, 7 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java index 8e49ce15265..b7a8a4a4e43 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; @@ -72,14 +73,14 @@ public class Reduce extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! allInputTypesPresent(1)) return null; - TensorFunction inputFunction = inputs.get(0).function().get(); + TensorFunction<Reference> inputFunction = inputs.get(0).function().get(); if (preOperator != null) { - inputFunction = new com.yahoo.tensor.functions.Map(inputFunction, preOperator); + inputFunction = new com.yahoo.tensor.functions.Map<>(inputFunction, preOperator); } - TensorFunction output = new com.yahoo.tensor.functions.Reduce(inputFunction, aggregator, reduceDimensions); + TensorFunction<Reference> output = new com.yahoo.tensor.functions.Reduce<>(inputFunction, aggregator, reduceDimensions); if (shouldKeepDimensions()) { // multiply with a generated tensor created from the reduced dimensions TensorType.Builder typeBuilder = new TensorType.Builder(resultValueType()); @@ -88,12 +89,12 @@ public class Reduce extends IntermediateOperation { } TensorType generatedType = typeBuilder.build(); ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1)); - Generate generatedFunction = new Generate(generatedType, + Generate<Reference> generatedFunction = new Generate<>(generatedType, new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator()); - output = new com.yahoo.tensor.functions.Join(output, generatedFunction, ScalarFunctions.multiply()); + output = new com.yahoo.tensor.functions.Join<>(output, generatedFunction, ScalarFunctions.multiply()); } if (postOperator != null) { - output = new com.yahoo.tensor.functions.Map(output, postOperator); + output = new com.yahoo.tensor.functions.Map<>(output, postOperator); } return output; } |