summaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java28
1 files changed, 28 insertions, 0 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
index 9a76662529d..73aa40927be 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
@@ -5,6 +5,8 @@ import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
+import com.yahoo.text.ParenthesisExpressionPrettyPrinter;
+import com.yahoo.text.Text;
import java.util.List;
import java.util.Optional;
@@ -51,6 +53,12 @@ public class MatMul extends IntermediateOperation {
List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions();
List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions();
+ assertTwoDimensions(aDimensions, inputs.get(0), "first argument");
+ assertTwoDimensions(bDimensions, inputs.get(1), "second argument");
+
+ System.out.println("Dimensions in a: " + aDimensions);
+ System.out.println("Dimensions in b: " + bDimensions);
+
String aDim0 = aDimensions.get(0).name();
String aDim1 = aDimensions.get(1).name();
String bDim0 = bDimensions.get(0).name();
@@ -67,4 +75,24 @@ public class MatMul extends IntermediateOperation {
renamer.addConstraint(bDim0, bDim1, DimensionRenamer::greaterThan, this);
}
+ private void assertTwoDimensions(List<TensorType.Dimension> dimensions, IntermediateOperation supplier, String inputDescription) {
+ if (dimensions.size() >= 2) return;
+
+
+ throw new IllegalArgumentException("Expected 2 dimensions in the " + inputDescription + " to " + this +
+ " but got just " + dimensions + " from\n" +
+ ParenthesisExpressionPrettyPrinter.prettyPrint(supplier.toFullString()));
+ }
+
+ @Override
+ public String toFullString() {
+ return "MatMul(" + inputs().get(0).toFullString() + ", " +
+ inputs().get(1).toFullString() + ")" + " : " + lazyGetType();
+ }
+
+ @Override
+ public String toString() {
+ return "MatMul(" + asString(inputs().get(0).type()) + ", " + asString(inputs().get(1).type()) + ")";
+ }
+
}