diff options
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java | 28 |
1 files changed, 28 insertions, 0 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java index 9a76662529d..73aa40927be 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java @@ -5,6 +5,8 @@ import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; +import com.yahoo.text.ParenthesisExpressionPrettyPrinter; +import com.yahoo.text.Text; import java.util.List; import java.util.Optional; @@ -51,6 +53,12 @@ public class MatMul extends IntermediateOperation { List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions(); List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions(); + assertTwoDimensions(aDimensions, inputs.get(0), "first argument"); + assertTwoDimensions(bDimensions, inputs.get(1), "second argument"); + + System.out.println("Dimensions in a: " + aDimensions); + System.out.println("Dimensions in b: " + bDimensions); + String aDim0 = aDimensions.get(0).name(); String aDim1 = aDimensions.get(1).name(); String bDim0 = bDimensions.get(0).name(); @@ -67,4 +75,24 @@ public class MatMul extends IntermediateOperation { renamer.addConstraint(bDim0, bDim1, DimensionRenamer::greaterThan, this); } + private void assertTwoDimensions(List<TensorType.Dimension> dimensions, IntermediateOperation supplier, String inputDescription) { + if (dimensions.size() >= 2) return; + + + throw new IllegalArgumentException("Expected 2 dimensions in the " + inputDescription + " to " + this + + " but got just " + dimensions + " from\n" + + ParenthesisExpressionPrettyPrinter.prettyPrint(supplier.toFullString())); + } + + @Override + public String toFullString() { + return "MatMul(" + inputs().get(0).toFullString() + ", " + + inputs().get(1).toFullString() + ")" + " : " + lazyGetType(); + } + + @Override + public String toString() { + return "MatMul(" + asString(inputs().get(0).type()) + ", " + asString(inputs().get(1).type()) + ")"; + } + } |