diff options
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java | 159 |
1 files changed, 123 insertions, 36 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java index 6849e64641e..2b0af93fd8e 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java @@ -3,13 +3,23 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Join; +import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.ScalarFunctions; +import com.yahoo.tensor.functions.Slice; import com.yahoo.tensor.functions.TensorFunction; -import com.yahoo.text.ExpressionFormatter; +import java.util.ArrayList; import java.util.List; import java.util.Optional; +import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar; + public class MatMul extends IntermediateOperation { public MatMul(String modelName, String nodeName, List<IntermediateOperation> inputs) { @@ -20,62 +30,139 @@ public class MatMul extends IntermediateOperation { protected OrderedTensorType lazyGetType() { if ( ! allInputTypesPresent(2)) return null; + OrderedTensorType typeA = inputs.get(0).type().get(); + OrderedTensorType typeB = inputs.get(1).type().get(); + + if (typeA.type().rank() < 1 || typeB.type().rank() < 1) + throw new IllegalArgumentException("Tensors in matmul must have rank of at least 1"); + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType()); - typeBuilder.add(inputs.get(0).type().get().dimensions().get(0)); - typeBuilder.add(inputs.get(1).type().get().dimensions().get(1)); + OrderedTensorType largestRankType = typeA.rank() >= typeB.rank() ? typeA : typeB; + OrderedTensorType smallestRankType = typeA.rank() >= typeB.rank() ? typeB : typeA; + for (int i = 0; i < largestRankType.rank() - 2; ++i) { + TensorType.Dimension dim = largestRankType.dimensions().get(i); + // broadcasting + int j = smallestRankType.rank() - largestRankType.rank() + i; + if (j >= 0 && smallestRankType.dimensions().get(j).size().get() > dim.size().get()) { + dim = smallestRankType.dimensions().get(j); + } + typeBuilder.add(dim); + } + if (typeA.rank() >= 2) { + typeBuilder.add(typeA.dimensions().get(typeA.rank() - 2)); + } + if (typeB.rank() >= 2) { + typeBuilder.add(typeB.dimensions().get(typeB.rank() - 1)); + } return typeBuilder.build(); } @Override protected TensorFunction lazyGetFunction() { if ( ! allInputTypesPresent(2)) return null; + if ( ! allInputFunctionsPresent(2)) return null; + + OrderedTensorType typeA = inputs.get(0).type().get(); + OrderedTensorType typeB = inputs.get(1).type().get(); + + TensorFunction functionA = handleBroadcasting(inputs.get(0).function().get(), typeA, typeB); + TensorFunction functionB = handleBroadcasting(inputs.get(1).function().get(), typeB, typeA); + + return new com.yahoo.tensor.functions.Reduce( + new Join(functionA, functionB, ScalarFunctions.multiply()), + Reduce.Aggregator.sum, + typeA.dimensions().get(typeA.rank() - 1).name()); + } - OrderedTensorType aType = inputs.get(0).type().get(); - OrderedTensorType bType = inputs.get(1).type().get(); - if (aType.type().rank() < 2 || bType.type().rank() < 2) - throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2"); - if (aType.type().rank() != bType.type().rank()) - throw new IllegalArgumentException("Tensors in matmul must have the same rank"); - - Optional<TensorFunction> aFunction = inputs.get(0).function(); - Optional<TensorFunction> bFunction = inputs.get(1).function(); - if (!aFunction.isPresent() || !bFunction.isPresent()) { - return null; + private TensorFunction handleBroadcasting(TensorFunction tensorFunction, OrderedTensorType typeA, OrderedTensorType typeB) { + List<Slice.DimensionValue> slices = new ArrayList<>(); + for (int i = 0; i < typeA.rank() - 2; ++i) { + long dimSizeA = typeA.dimensions().get(i).size().get(); + String dimNameA = typeA.dimensionNames().get(i); + int j = typeB.rank() - typeA.rank() + i; + if (j >= 0) { + long dimSizeB = typeB.dimensions().get(j).size().get(); + if (dimSizeB > dimSizeA && dimSizeA == 1) { + ExpressionNode dimensionExpression = new EmbracedNode(new ConstantNode(DoubleValue.zero)); + slices.add(new Slice.DimensionValue(Optional.of(dimNameA), wrapScalar(dimensionExpression))); + } + } } - return new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), aType.dimensions().get(1).name()); + return slices.size() == 0 ? tensorFunction : new Slice(tensorFunction, slices); } @Override public void addDimensionNameConstraints(DimensionRenamer renamer) { if ( ! allInputTypesPresent(2)) return; - List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions(); - List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions(); - - assertTwoDimensions(aDimensions, inputs.get(0), "first argument"); - assertTwoDimensions(bDimensions, inputs.get(1), "second argument"); + OrderedTensorType typeA = inputs.get(0).type().get(); + OrderedTensorType typeB = inputs.get(1).type().get(); - String aDim0 = aDimensions.get(0).name(); - String aDim1 = aDimensions.get(1).name(); - String bDim0 = bDimensions.get(0).name(); - String bDim1 = bDimensions.get(1).name(); + String lastDimA = typeA.dimensions().get(typeA.rank()-1).name(); + String lastDimB = typeB.dimensions().get(typeB.rank()-1).name(); + String secondLastDimA = typeA.dimensions().get(Math.max(0,typeA.rank()-2)).name(); + String secondLastDimB = typeB.dimensions().get(Math.max(0,typeB.rank()-2)).name(); - // The second dimension of a should have the same name as the first dimension of b - renamer.addConstraint(aDim1, bDim0, DimensionRenamer.Constraint.equal(false), this); + // The last dimension of A should have the same name as the second-to-last dimension of B + renamer.addConstraint(lastDimA, secondLastDimB, DimensionRenamer.Constraint.equal(false), this); - // The first dimension of a should have a different name than the second dimension of b - renamer.addConstraint(aDim0, bDim1, DimensionRenamer.Constraint.lessThan(false), this); + // The second-to-last dimension of a should have a different name than the last dimension of b + if (typeA.rank() >= 2 && typeB.rank() >= 2) { + renamer.addConstraint(secondLastDimA, lastDimB, DimensionRenamer.Constraint.lessThan(false), this); + } // For efficiency, the dimensions to join over should be innermost - soft constraint - renamer.addConstraint(aDim0, aDim1, DimensionRenamer.Constraint.lessThan(true), this); - renamer.addConstraint(bDim0, bDim1, DimensionRenamer.Constraint.greaterThan(true), this); - } + if (typeA.rank() >= 2) { + renamer.addConstraint(secondLastDimA, lastDimA, DimensionRenamer.Constraint.lessThan(true), this); + } + if (typeB.rank() >= 2) { + renamer.addConstraint(secondLastDimB, lastDimB, DimensionRenamer.Constraint.greaterThan(true), this); + } + + // Handle different cases when at least one of the tensors have rank > 2 + for (int i = 0; i < typeA.rank() - 2; ++i) { + String iDim = typeA.dimensionNames().get(i); + + // a1 < a2 < a3 < a4 + for (int j = i+1; j < typeA.rank(); ++j) { + String jDim = typeA.dimensionNames().get(j); + renamer.addConstraint(iDim, jDim, DimensionRenamer.Constraint.lessThan(true), this); + } + // not equal to last 2 dimensions in B + for (int j = typeB.rank()-2; j < typeB.rank(); ++j) { + if (j < 0) continue; + String jDim = typeB.dimensionNames().get(j); + renamer.addConstraint(iDim, jDim, DimensionRenamer.Constraint.notEqual(false), this); + } + // equal to matching dimension in tensor B + int j = typeB.rank() - typeA.rank() + i; + if (j >= 0) { + String jDim = typeB.dimensionNames().get(j); + renamer.addConstraint(iDim, jDim, DimensionRenamer.Constraint.equal(false), this); + } + } - private void assertTwoDimensions(List<TensorType.Dimension> dimensions, IntermediateOperation supplier, String inputDescription) { - if (dimensions.size() >= 2) return; - throw new IllegalArgumentException("Expected 2 dimensions in the " + inputDescription + " to " + this + - " but got just " + dimensions + " from\n" + - ExpressionFormatter.inTwoColumnMode(70, 50).format(supplier.toFullString())); + for (int i = 0; i < typeB.rank() - 2; ++i) { + String iDim = typeB.dimensionNames().get(i); + + // b1 < b2 < b3 < b4 + for (int j = i+1; j < typeB.rank(); ++j) { + String jDim = typeB.dimensionNames().get(j); + renamer.addConstraint(iDim, jDim, DimensionRenamer.Constraint.lessThan(true), this); + } + // not equal to last 2 dimensions in A + for (int j = typeA.rank()-2; j < typeA.rank(); ++j) { + if (j < 0) continue; + String jDim = typeA.dimensionNames().get(j); + renamer.addConstraint(iDim, jDim, DimensionRenamer.Constraint.notEqual(false), this); + } + // equal to matching dimension in tensor A + int j = typeA.rank() - typeB.rank() + i; + if (j >= 0) { + String jDim = typeA.dimensionNames().get(j); + renamer.addConstraint(iDim, jDim, DimensionRenamer.Constraint.equal(false), this); + } + } } @Override |