summaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java12
1 files changed, 6 insertions, 6 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
index 1f447f2a575..97bfdda385e 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
@@ -5,8 +5,8 @@ import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.TensorType;
@@ -97,17 +97,17 @@ public class Gemm extends IntermediateOperation {
TensorFunction<Reference> AxB = new com.yahoo.tensor.functions.Matmul<>(aFunction.get(), bFunction.get(), joinDimension);
TensorFunction<Reference> alphaxAxB = new TensorFunctionNode.ExpressionTensorFunction(
- new ArithmeticNode(
+ new OperationNode(
new TensorFunctionNode(AxB),
- ArithmeticOperator.MULTIPLY,
+ Operator.multiply,
new ConstantNode(new DoubleValue(alpha))));
if (inputs.size() == 3) {
Optional<TensorFunction<Reference>> cFunction = inputs.get(2).function();
TensorFunction<Reference> betaxC = new TensorFunctionNode.ExpressionTensorFunction(
- new ArithmeticNode(
+ new OperationNode(
new TensorFunctionNode(cFunction.get()),
- ArithmeticOperator.MULTIPLY,
+ Operator.multiply,
new ConstantNode(new DoubleValue(beta))));
return new com.yahoo.tensor.functions.Join<>(alphaxAxB, betaxC, ScalarFunctions.add());
}