summaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java17
1 files changed, 9 insertions, 8 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
index 4b3208fdeb0..1f447f2a575 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
@@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
@@ -78,7 +79,7 @@ public class Gemm extends IntermediateOperation {
}
@Override
- protected TensorFunction lazyGetFunction() {
+ protected TensorFunction<Reference> lazyGetFunction() {
if ( ! check2or3InputsPresent()) return null;
OrderedTensorType aType = inputs.get(0).type().get();
@@ -86,29 +87,29 @@ public class Gemm extends IntermediateOperation {
if (aType.type().rank() != 2 || bType.type().rank() != 2)
throw new IllegalArgumentException("Tensors in Gemm must have rank of exactly 2");
- Optional<TensorFunction> aFunction = inputs.get(0).function();
- Optional<TensorFunction> bFunction = inputs.get(1).function();
+ Optional<TensorFunction<Reference>> aFunction = inputs.get(0).function();
+ Optional<TensorFunction<Reference>> bFunction = inputs.get(1).function();
if (aFunction.isEmpty() || bFunction.isEmpty()) {
return null;
}
String joinDimension = aType.dimensions().get(1 - transposeA).name();
- TensorFunction AxB = new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), joinDimension);
- TensorFunction alphaxAxB = new TensorFunctionNode.ExpressionTensorFunction(
+ TensorFunction<Reference> AxB = new com.yahoo.tensor.functions.Matmul<>(aFunction.get(), bFunction.get(), joinDimension);
+ TensorFunction<Reference> alphaxAxB = new TensorFunctionNode.ExpressionTensorFunction(
new ArithmeticNode(
new TensorFunctionNode(AxB),
ArithmeticOperator.MULTIPLY,
new ConstantNode(new DoubleValue(alpha))));
if (inputs.size() == 3) {
- Optional<TensorFunction> cFunction = inputs.get(2).function();
- TensorFunction betaxC = new TensorFunctionNode.ExpressionTensorFunction(
+ Optional<TensorFunction<Reference>> cFunction = inputs.get(2).function();
+ TensorFunction<Reference> betaxC = new TensorFunctionNode.ExpressionTensorFunction(
new ArithmeticNode(
new TensorFunctionNode(cFunction.get()),
ArithmeticOperator.MULTIPLY,
new ConstantNode(new DoubleValue(beta))));
- return new com.yahoo.tensor.functions.Join(alphaxAxB, betaxC, ScalarFunctions.add());
+ return new com.yahoo.tensor.functions.Join<>(alphaxAxB, betaxC, ScalarFunctions.add());
}
return alphaxAxB;