diff options
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java index 1f447f2a575..97bfdda385e 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java @@ -5,8 +5,8 @@ import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.tensor.TensorType; @@ -97,17 +97,17 @@ public class Gemm extends IntermediateOperation { TensorFunction<Reference> AxB = new com.yahoo.tensor.functions.Matmul<>(aFunction.get(), bFunction.get(), joinDimension); TensorFunction<Reference> alphaxAxB = new TensorFunctionNode.ExpressionTensorFunction( - new ArithmeticNode( + new OperationNode( new TensorFunctionNode(AxB), - ArithmeticOperator.MULTIPLY, + Operator.multiply, new ConstantNode(new DoubleValue(alpha)))); if (inputs.size() == 3) { Optional<TensorFunction<Reference>> cFunction = inputs.get(2).function(); TensorFunction<Reference> betaxC = new TensorFunctionNode.ExpressionTensorFunction( - new ArithmeticNode( + new OperationNode( new TensorFunctionNode(cFunction.get()), - ArithmeticOperator.MULTIPLY, + Operator.multiply, new ConstantNode(new DoubleValue(beta)))); return new com.yahoo.tensor.functions.Join<>(alphaxAxB, betaxC, ScalarFunctions.add()); } |