aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java159
1 files changed, 123 insertions, 36 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
index 6849e64641e..2b0af93fd8e 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
@@ -3,13 +3,23 @@ package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Join;
+import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.Slice;
import com.yahoo.tensor.functions.TensorFunction;
-import com.yahoo.text.ExpressionFormatter;
+import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
+import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar;
+
public class MatMul extends IntermediateOperation {
public MatMul(String modelName, String nodeName, List<IntermediateOperation> inputs) {
@@ -20,62 +30,139 @@ public class MatMul extends IntermediateOperation {
protected OrderedTensorType lazyGetType() {
if ( ! allInputTypesPresent(2)) return null;
+ OrderedTensorType typeA = inputs.get(0).type().get();
+ OrderedTensorType typeB = inputs.get(1).type().get();
+
+ if (typeA.type().rank() < 1 || typeB.type().rank() < 1)
+ throw new IllegalArgumentException("Tensors in matmul must have rank of at least 1");
+
OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType());
- typeBuilder.add(inputs.get(0).type().get().dimensions().get(0));
- typeBuilder.add(inputs.get(1).type().get().dimensions().get(1));
+ OrderedTensorType largestRankType = typeA.rank() >= typeB.rank() ? typeA : typeB;
+ OrderedTensorType smallestRankType = typeA.rank() >= typeB.rank() ? typeB : typeA;
+ for (int i = 0; i < largestRankType.rank() - 2; ++i) {
+ TensorType.Dimension dim = largestRankType.dimensions().get(i);
+ // broadcasting
+ int j = smallestRankType.rank() - largestRankType.rank() + i;
+ if (j >= 0 && smallestRankType.dimensions().get(j).size().get() > dim.size().get()) {
+ dim = smallestRankType.dimensions().get(j);
+ }
+ typeBuilder.add(dim);
+ }
+ if (typeA.rank() >= 2) {
+ typeBuilder.add(typeA.dimensions().get(typeA.rank() - 2));
+ }
+ if (typeB.rank() >= 2) {
+ typeBuilder.add(typeB.dimensions().get(typeB.rank() - 1));
+ }
return typeBuilder.build();
}
@Override
protected TensorFunction lazyGetFunction() {
if ( ! allInputTypesPresent(2)) return null;
+ if ( ! allInputFunctionsPresent(2)) return null;
+
+ OrderedTensorType typeA = inputs.get(0).type().get();
+ OrderedTensorType typeB = inputs.get(1).type().get();
+
+ TensorFunction functionA = handleBroadcasting(inputs.get(0).function().get(), typeA, typeB);
+ TensorFunction functionB = handleBroadcasting(inputs.get(1).function().get(), typeB, typeA);
+
+ return new com.yahoo.tensor.functions.Reduce(
+ new Join(functionA, functionB, ScalarFunctions.multiply()),
+ Reduce.Aggregator.sum,
+ typeA.dimensions().get(typeA.rank() - 1).name());
+ }
- OrderedTensorType aType = inputs.get(0).type().get();
- OrderedTensorType bType = inputs.get(1).type().get();
- if (aType.type().rank() < 2 || bType.type().rank() < 2)
- throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2");
- if (aType.type().rank() != bType.type().rank())
- throw new IllegalArgumentException("Tensors in matmul must have the same rank");
-
- Optional<TensorFunction> aFunction = inputs.get(0).function();
- Optional<TensorFunction> bFunction = inputs.get(1).function();
- if (!aFunction.isPresent() || !bFunction.isPresent()) {
- return null;
+ private TensorFunction handleBroadcasting(TensorFunction tensorFunction, OrderedTensorType typeA, OrderedTensorType typeB) {
+ List<Slice.DimensionValue> slices = new ArrayList<>();
+ for (int i = 0; i < typeA.rank() - 2; ++i) {
+ long dimSizeA = typeA.dimensions().get(i).size().get();
+ String dimNameA = typeA.dimensionNames().get(i);
+ int j = typeB.rank() - typeA.rank() + i;
+ if (j >= 0) {
+ long dimSizeB = typeB.dimensions().get(j).size().get();
+ if (dimSizeB > dimSizeA && dimSizeA == 1) {
+ ExpressionNode dimensionExpression = new EmbracedNode(new ConstantNode(DoubleValue.zero));
+ slices.add(new Slice.DimensionValue(Optional.of(dimNameA), wrapScalar(dimensionExpression)));
+ }
+ }
}
- return new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), aType.dimensions().get(1).name());
+ return slices.size() == 0 ? tensorFunction : new Slice(tensorFunction, slices);
}
@Override
public void addDimensionNameConstraints(DimensionRenamer renamer) {
if ( ! allInputTypesPresent(2)) return;
- List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions();
- List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions();
-
- assertTwoDimensions(aDimensions, inputs.get(0), "first argument");
- assertTwoDimensions(bDimensions, inputs.get(1), "second argument");
+ OrderedTensorType typeA = inputs.get(0).type().get();
+ OrderedTensorType typeB = inputs.get(1).type().get();
- String aDim0 = aDimensions.get(0).name();
- String aDim1 = aDimensions.get(1).name();
- String bDim0 = bDimensions.get(0).name();
- String bDim1 = bDimensions.get(1).name();
+ String lastDimA = typeA.dimensions().get(typeA.rank()-1).name();
+ String lastDimB = typeB.dimensions().get(typeB.rank()-1).name();
+ String secondLastDimA = typeA.dimensions().get(Math.max(0,typeA.rank()-2)).name();
+ String secondLastDimB = typeB.dimensions().get(Math.max(0,typeB.rank()-2)).name();
- // The second dimension of a should have the same name as the first dimension of b
- renamer.addConstraint(aDim1, bDim0, DimensionRenamer.Constraint.equal(false), this);
+ // The last dimension of A should have the same name as the second-to-last dimension of B
+ renamer.addConstraint(lastDimA, secondLastDimB, DimensionRenamer.Constraint.equal(false), this);
- // The first dimension of a should have a different name than the second dimension of b
- renamer.addConstraint(aDim0, bDim1, DimensionRenamer.Constraint.lessThan(false), this);
+ // The second-to-last dimension of a should have a different name than the last dimension of b
+ if (typeA.rank() >= 2 && typeB.rank() >= 2) {
+ renamer.addConstraint(secondLastDimA, lastDimB, DimensionRenamer.Constraint.lessThan(false), this);
+ }
// For efficiency, the dimensions to join over should be innermost - soft constraint
- renamer.addConstraint(aDim0, aDim1, DimensionRenamer.Constraint.lessThan(true), this);
- renamer.addConstraint(bDim0, bDim1, DimensionRenamer.Constraint.greaterThan(true), this);
- }
+ if (typeA.rank() >= 2) {
+ renamer.addConstraint(secondLastDimA, lastDimA, DimensionRenamer.Constraint.lessThan(true), this);
+ }
+ if (typeB.rank() >= 2) {
+ renamer.addConstraint(secondLastDimB, lastDimB, DimensionRenamer.Constraint.greaterThan(true), this);
+ }
+
+ // Handle different cases when at least one of the tensors have rank > 2
+ for (int i = 0; i < typeA.rank() - 2; ++i) {
+ String iDim = typeA.dimensionNames().get(i);
+
+ // a1 < a2 < a3 < a4
+ for (int j = i+1; j < typeA.rank(); ++j) {
+ String jDim = typeA.dimensionNames().get(j);
+ renamer.addConstraint(iDim, jDim, DimensionRenamer.Constraint.lessThan(true), this);
+ }
+ // not equal to last 2 dimensions in B
+ for (int j = typeB.rank()-2; j < typeB.rank(); ++j) {
+ if (j < 0) continue;
+ String jDim = typeB.dimensionNames().get(j);
+ renamer.addConstraint(iDim, jDim, DimensionRenamer.Constraint.notEqual(false), this);
+ }
+ // equal to matching dimension in tensor B
+ int j = typeB.rank() - typeA.rank() + i;
+ if (j >= 0) {
+ String jDim = typeB.dimensionNames().get(j);
+ renamer.addConstraint(iDim, jDim, DimensionRenamer.Constraint.equal(false), this);
+ }
+ }
- private void assertTwoDimensions(List<TensorType.Dimension> dimensions, IntermediateOperation supplier, String inputDescription) {
- if (dimensions.size() >= 2) return;
- throw new IllegalArgumentException("Expected 2 dimensions in the " + inputDescription + " to " + this +
- " but got just " + dimensions + " from\n" +
- ExpressionFormatter.inTwoColumnMode(70, 50).format(supplier.toFullString()));
+ for (int i = 0; i < typeB.rank() - 2; ++i) {
+ String iDim = typeB.dimensionNames().get(i);
+
+ // b1 < b2 < b3 < b4
+ for (int j = i+1; j < typeB.rank(); ++j) {
+ String jDim = typeB.dimensionNames().get(j);
+ renamer.addConstraint(iDim, jDim, DimensionRenamer.Constraint.lessThan(true), this);
+ }
+ // not equal to last 2 dimensions in A
+ for (int j = typeA.rank()-2; j < typeA.rank(); ++j) {
+ if (j < 0) continue;
+ String jDim = typeA.dimensionNames().get(j);
+ renamer.addConstraint(iDim, jDim, DimensionRenamer.Constraint.notEqual(false), this);
+ }
+ // equal to matching dimension in tensor A
+ int j = typeA.rank() - typeB.rank() + i;
+ if (j >= 0) {
+ String jDim = typeA.dimensionNames().get(j);
+ renamer.addConstraint(iDim, jDim, DimensionRenamer.Constraint.equal(false), this);
+ }
+ }
}
@Override