diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2022-09-28 22:54:13 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-09-28 22:54:13 +0200 |
commit | 12992ecdc0e77968eb5c5544f2ae7d855e443162 (patch) | |
tree | ac8cec3ae02f27ae638876940399f490b4ac4ab1 /model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java | |
parent | d50f7bd9c99ed9d8edeabb71825f3966f9cd6bd9 (diff) | |
parent | fb0074925e9e8358d38145dc5753de1c935f737d (diff) |
Merge pull request #24251 from vespa-engine/bratseth/operatorsv8.61.17
Bratseth/operators
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java index 1f447f2a575..97bfdda385e 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java @@ -5,8 +5,8 @@ import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.tensor.TensorType; @@ -97,17 +97,17 @@ public class Gemm extends IntermediateOperation { TensorFunction<Reference> AxB = new com.yahoo.tensor.functions.Matmul<>(aFunction.get(), bFunction.get(), joinDimension); TensorFunction<Reference> alphaxAxB = new TensorFunctionNode.ExpressionTensorFunction( - new ArithmeticNode( + new OperationNode( new TensorFunctionNode(AxB), - ArithmeticOperator.MULTIPLY, + Operator.multiply, new ConstantNode(new DoubleValue(alpha)))); if (inputs.size() == 3) { Optional<TensorFunction<Reference>> cFunction = inputs.get(2).function(); TensorFunction<Reference> betaxC = new TensorFunctionNode.ExpressionTensorFunction( - new ArithmeticNode( + new OperationNode( new TensorFunctionNode(cFunction.get()), - ArithmeticOperator.MULTIPLY, + Operator.multiply, new ConstantNode(new DoubleValue(beta)))); return new com.yahoo.tensor.functions.Join<>(alphaxAxB, betaxC, ScalarFunctions.add()); } |