diff options
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java | 18 |
1 files changed, 8 insertions, 10 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java index 1dbfd6e40dc..9a76662529d 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java @@ -17,10 +17,9 @@ public class MatMul extends IntermediateOperation { @Override protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(2)) { - return null; - } - OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(); + if ( ! allInputTypesPresent(2)) return null; + + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType()); typeBuilder.add(inputs.get(0).type().get().dimensions().get(0)); typeBuilder.add(inputs.get(1).type().get().dimensions().get(1)); return typeBuilder.build(); @@ -28,9 +27,8 @@ public class MatMul extends IntermediateOperation { @Override protected TensorFunction lazyGetFunction() { - if (!allInputTypesPresent(2)) { - return null; - } + if ( ! allInputTypesPresent(2)) return null; + OrderedTensorType aType = inputs.get(0).type().get(); OrderedTensorType bType = inputs.get(1).type().get(); if (aType.type().rank() < 2 || bType.type().rank() < 2) @@ -48,9 +46,8 @@ public class MatMul extends IntermediateOperation { @Override public void addDimensionNameConstraints(DimensionRenamer renamer) { - if (!allInputTypesPresent(2)) { - return; - } + if ( ! allInputTypesPresent(2)) return; + List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions(); List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions(); @@ -69,4 +66,5 @@ public class MatMul extends IntermediateOperation { renamer.addConstraint(aDim0, aDim1, DimensionRenamer::lesserThan, this); renamer.addConstraint(bDim0, bDim1, DimensionRenamer::greaterThan, this); } + } |