summaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java18
1 files changed, 8 insertions, 10 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
index 1dbfd6e40dc..9a76662529d 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
@@ -17,10 +17,9 @@ public class MatMul extends IntermediateOperation {
@Override
protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
- OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder();
+ if ( ! allInputTypesPresent(2)) return null;
+
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType());
typeBuilder.add(inputs.get(0).type().get().dimensions().get(0));
typeBuilder.add(inputs.get(1).type().get().dimensions().get(1));
return typeBuilder.build();
@@ -28,9 +27,8 @@ public class MatMul extends IntermediateOperation {
@Override
protected TensorFunction lazyGetFunction() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
+ if ( ! allInputTypesPresent(2)) return null;
+
OrderedTensorType aType = inputs.get(0).type().get();
OrderedTensorType bType = inputs.get(1).type().get();
if (aType.type().rank() < 2 || bType.type().rank() < 2)
@@ -48,9 +46,8 @@ public class MatMul extends IntermediateOperation {
@Override
public void addDimensionNameConstraints(DimensionRenamer renamer) {
- if (!allInputTypesPresent(2)) {
- return;
- }
+ if ( ! allInputTypesPresent(2)) return;
+
List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions();
List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions();
@@ -69,4 +66,5 @@ public class MatMul extends IntermediateOperation {
renamer.addConstraint(aDim0, aDim1, DimensionRenamer::lesserThan, this);
renamer.addConstraint(bDim0, bDim1, DimensionRenamer::greaterThan, this);
}
+
}