diff options
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java | 131 |
1 files changed, 98 insertions, 33 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java index 6849e64641e..1eb21eb2a5e 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java @@ -4,6 +4,9 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Join; +import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.ScalarFunctions; import com.yahoo.tensor.functions.TensorFunction; import com.yahoo.text.ExpressionFormatter; @@ -20,64 +23,126 @@ public class MatMul extends IntermediateOperation { protected OrderedTensorType lazyGetType() { if ( ! allInputTypesPresent(2)) return null; + OrderedTensorType aType = inputs.get(0).type().get(); + OrderedTensorType bType = inputs.get(1).type().get(); + + // add some more checks here + if (aType.type().rank() < 1 || bType.type().rank() < 1) + throw new IllegalArgumentException("Tensors in matmul must have rank of at least 1"); + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType()); - typeBuilder.add(inputs.get(0).type().get().dimensions().get(0)); - typeBuilder.add(inputs.get(1).type().get().dimensions().get(1)); + OrderedTensorType largestRankType = aType.rank() >= bType.rank() ? aType : bType; + for (int i = 0; i < largestRankType.rank() - 2; ++i) { + typeBuilder.add(largestRankType.dimensions().get(i)); + } + if (aType.rank() >= 2) { + typeBuilder.add(aType.dimensions().get(aType.rank() - 2)); + } + if (bType.rank() >= 2) { + typeBuilder.add(bType.dimensions().get(bType.rank() - 1)); + } return typeBuilder.build(); } @Override protected TensorFunction lazyGetFunction() { if ( ! allInputTypesPresent(2)) return null; + if ( ! allInputFunctionsPresent(2)) return null; OrderedTensorType aType = inputs.get(0).type().get(); - OrderedTensorType bType = inputs.get(1).type().get(); - if (aType.type().rank() < 2 || bType.type().rank() < 2) - throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2"); - if (aType.type().rank() != bType.type().rank()) - throw new IllegalArgumentException("Tensors in matmul must have the same rank"); - Optional<TensorFunction> aFunction = inputs.get(0).function(); Optional<TensorFunction> bFunction = inputs.get(1).function(); - if (!aFunction.isPresent() || !bFunction.isPresent()) { - return null; - } - return new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), aType.dimensions().get(1).name()); + + // only change to this is for dimensions with size 1 - check in getType + + return new com.yahoo.tensor.functions.Reduce(new Join(aFunction.get(), bFunction.get(), ScalarFunctions.multiply()), + Reduce.Aggregator.sum, + aType.dimensions().get(aType.rank() - 1).name()); } @Override public void addDimensionNameConstraints(DimensionRenamer renamer) { if ( ! allInputTypesPresent(2)) return; - List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions(); - List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions(); + /* + * A: a1, a2, a3, a4 + * B: b1, b2, b3, b4 + * + * a4 == b3 + * a3 < b4 + * a3 < a4 + * b4 < b3 + * + * a1 == b1 -> men også størrelsesmessig. + * a2 == b2 + * etc + */ + + OrderedTensorType typeA = inputs.get(0).type().get(); + OrderedTensorType typeB = inputs.get(1).type().get(); + + String lastDimA = typeA.dimensions().get(typeA.rank()-1).name(); + String lastDimB = typeB.dimensions().get(typeB.rank()-1).name(); + String secondLastDimA = typeA.dimensions().get(Math.max(0,typeA.rank()-2)).name(); + String secondLastDimB = typeB.dimensions().get(Math.max(0,typeB.rank()-2)).name(); + + // The last dimension of A should have the same name as the second-to-last dimension of B + renamer.addConstraint(lastDimA, secondLastDimB, DimensionRenamer.Constraint.equal(false), this); - assertTwoDimensions(aDimensions, inputs.get(0), "first argument"); - assertTwoDimensions(bDimensions, inputs.get(1), "second argument"); + // For efficiency, the dimensions to join over should be innermost - soft constraint + if (typeA.rank() >= 2) { + renamer.addConstraint(secondLastDimA, lastDimA, DimensionRenamer.Constraint.lessThan(true), this); + } + if (typeB.rank() >= 2) { + renamer.addConstraint(secondLastDimB, lastDimB, DimensionRenamer.Constraint.greaterThan(true), this); + } - String aDim0 = aDimensions.get(0).name(); - String aDim1 = aDimensions.get(1).name(); - String bDim0 = bDimensions.get(0).name(); - String bDim1 = bDimensions.get(1).name(); + // The second-to-last dimension of a should have a different name than the last dimension of b + if (typeA.rank() >= 2 && typeB.rank() >= 2) { + renamer.addConstraint(secondLastDimA, lastDimB, DimensionRenamer.Constraint.lessThan(false), this); + } - // The second dimension of a should have the same name as the first dimension of b - renamer.addConstraint(aDim1, bDim0, DimensionRenamer.Constraint.equal(false), this); + // a1 < a2 < a3 < a4 + OrderedTensorType largestRankType = typeA.rank() >= typeB.rank() ? typeA : typeB; + for (int i = 0; i < largestRankType.rank() - 2; ++i) { + String iDim = largestRankType.dimensionNames().get(i); + for (int j = i+1; j < largestRankType.rank() - 2; ++j) { + String jDim = largestRankType.dimensionNames().get(j); + renamer.addConstraint(iDim, jDim, DimensionRenamer.Constraint.lessThan(true), this); + } + } + + // TODO: handle non similar sizes + + // a1 == b1 etc + if (typeA.rank() == typeB.rank()) { + for (int i = 0; i < typeA.rank() - 2; ++i) { + renamer.addConstraint(typeA.dimensionNames().get(i), typeB.dimensionNames().get(i), DimensionRenamer.Constraint.equal(false), this); + } + } - // The first dimension of a should have a different name than the second dimension of b - renamer.addConstraint(aDim0, bDim1, DimensionRenamer.Constraint.lessThan(false), this); - // For efficiency, the dimensions to join over should be innermost - soft constraint - renamer.addConstraint(aDim0, aDim1, DimensionRenamer.Constraint.lessThan(true), this); - renamer.addConstraint(bDim0, bDim1, DimensionRenamer.Constraint.greaterThan(true), this); - } - private void assertTwoDimensions(List<TensorType.Dimension> dimensions, IntermediateOperation supplier, String inputDescription) { - if (dimensions.size() >= 2) return; - throw new IllegalArgumentException("Expected 2 dimensions in the " + inputDescription + " to " + this + - " but got just " + dimensions + " from\n" + - ExpressionFormatter.inTwoColumnMode(70, 50).format(supplier.toFullString())); + + // So, what about the other dimensions? +// if (aDimensions.size() > 2) { +// for (int i = 1; i < aDimensions.size(); ++i) { +// renamer.addConstraint(aDimensions.get(0).name(), aDimensions.get(i).name(), DimensionRenamer.Constraint.notEqual(false), this); +// } +// for (int i = 0; i < bDimensions.size(); ++i) { +// renamer.addConstraint(aDimensions.get(0).name(), bDimensions.get(i).name(), DimensionRenamer.Constraint.notEqual(false), this); +// } +// } + } +// private void assertTwoDimensions(List<TensorType.Dimension> dimensions, IntermediateOperation supplier, String inputDescription) { +// if (dimensions.size() >= 2) return; +// throw new IllegalArgumentException("Expected 2 dimensions in the " + inputDescription + " to " + this + +// " but got just " + dimensions + " from\n" + +// ExpressionFormatter.inTwoColumnMode(70, 50).format(supplier.toFullString())); +// } + @Override public MatMul withInputs(List<IntermediateOperation> inputs) { return new MatMul(modelName(), name(), inputs); |