diff options
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java | 17 |
1 files changed, 9 insertions, 8 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java index 4b3208fdeb0..1f447f2a575 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; @@ -78,7 +79,7 @@ public class Gemm extends IntermediateOperation { } @Override - protected TensorFunction lazyGetFunction() { + protected TensorFunction<Reference> lazyGetFunction() { if ( ! check2or3InputsPresent()) return null; OrderedTensorType aType = inputs.get(0).type().get(); @@ -86,29 +87,29 @@ public class Gemm extends IntermediateOperation { if (aType.type().rank() != 2 || bType.type().rank() != 2) throw new IllegalArgumentException("Tensors in Gemm must have rank of exactly 2"); - Optional<TensorFunction> aFunction = inputs.get(0).function(); - Optional<TensorFunction> bFunction = inputs.get(1).function(); + Optional<TensorFunction<Reference>> aFunction = inputs.get(0).function(); + Optional<TensorFunction<Reference>> bFunction = inputs.get(1).function(); if (aFunction.isEmpty() || bFunction.isEmpty()) { return null; } String joinDimension = aType.dimensions().get(1 - transposeA).name(); - TensorFunction AxB = new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), joinDimension); - TensorFunction alphaxAxB = new TensorFunctionNode.ExpressionTensorFunction( + TensorFunction<Reference> AxB = new com.yahoo.tensor.functions.Matmul<>(aFunction.get(), bFunction.get(), joinDimension); + TensorFunction<Reference> alphaxAxB = new TensorFunctionNode.ExpressionTensorFunction( new ArithmeticNode( new TensorFunctionNode(AxB), ArithmeticOperator.MULTIPLY, new ConstantNode(new DoubleValue(alpha)))); if (inputs.size() == 3) { - Optional<TensorFunction> cFunction = inputs.get(2).function(); - TensorFunction betaxC = new TensorFunctionNode.ExpressionTensorFunction( + Optional<TensorFunction<Reference>> cFunction = inputs.get(2).function(); + TensorFunction<Reference> betaxC = new TensorFunctionNode.ExpressionTensorFunction( new ArithmeticNode( new TensorFunctionNode(cFunction.get()), ArithmeticOperator.MULTIPLY, new ConstantNode(new DoubleValue(beta)))); - return new com.yahoo.tensor.functions.Join(alphaxAxB, betaxC, ScalarFunctions.add()); + return new com.yahoo.tensor.functions.Join<>(alphaxAxB, betaxC, ScalarFunctions.add()); } return alphaxAxB; |