diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-02-02 13:24:45 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-02-02 13:24:45 +0100 |
commit | c288092ac9868efae10437a8c60974050e9fa799 (patch) | |
tree | 537e666d2c5ae7b06fa6e7b4f7ed33b6bb66d013 | |
parent | 468e0e16a60a5feaf6d5eec971ff06078b6bb694 (diff) |
Generalize dimension-wise
3 files changed, 49 insertions, 7 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java index 4f0ebc1c7e5..66b250736e8 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java @@ -78,13 +78,11 @@ public final class IfNode extends CompositeNode { public TensorType type(TypeContext context) { TensorType trueType = trueExpression.type(context); TensorType falseType = falseExpression.type(context); - - // Types of each branch must be compatible; the resulting type is the most general - if (trueType.isAssignableTo(falseType)) return falseType; - if (falseType.isAssignableTo(trueType)) return trueType; - throw new IllegalArgumentException("An if expression must produce compatible types in both " + - "alternatives, but the 'true' type is " + trueType + " while the " + - "'false' type is " + falseType); + return trueType.dimensionwiseGeneralizationWith(falseType).orElseThrow(() -> + new IllegalArgumentException("An if expression must produce compatible types in both " + + "alternatives, but the 'true' type is " + trueType + " while the " + + "'false' type is " + falseType) + ); } @Override diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeResolutionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeResolutionTestCase.java index d1ea0fcf2e4..5cac2215a00 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeResolutionTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeResolutionTestCase.java @@ -22,9 +22,12 @@ public class TypeResolutionTestCase { context.setType("query('x1')", TensorType.fromSpec("tensor(x[])")); context.setType("query('x2')", TensorType.fromSpec("tensor(x[10])")); context.setType("query('y1')", TensorType.fromSpec("tensor(y[])")); + context.setType("query('xy1')", TensorType.fromSpec("tensor(x[10],y[])")); + context.setType("query('xy2')", TensorType.fromSpec("tensor(x[],y[10])")); assertType("tensor(x[])", "query(x1)", context); assertType("tensor(x[])", "if (1>0, query(x1), query(x2))", context); + assertType("tensor(x[],y[])", "if (1>0, query(xy1), query(xy2))", context); assertIncompatibleType("if (1>0, query(x1), query(y1))", context); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 8ff9774fc7d..14cd3e70866 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -136,6 +136,47 @@ public class TensorType { return true; } + /** + * Returns the dimensionwise generalization of this and the given type, or empty if no generalization exists. + * A dimensionwise generalization exists if the two tensors share the same dimensions, and each dimension + * is compatible. + * For example, the dimensionwise generalization of tensor(x[],y[5]) and tensor(x[5],y[]) is tensor(x[],y[]) + */ + public Optional<TensorType> dimensionwiseGeneralizationWith(TensorType other) { + if (this.equals(other)) return Optional.of(this); // shortcut + if (this.dimensions.size() != other.dimensions.size()) return Optional.empty(); + + Builder b = new Builder(); + for (int i = 0; i < dimensions.size(); i++) { + Dimension thisDim = this.dimensions().get(i); + Dimension otherDim = other.dimensions().get(i); + if ( ! thisDim.name().equals(otherDim.name())) return Optional.empty(); + if (thisDim.isIndexed() && otherDim.isIndexed()) { + if (thisDim.size().isPresent() && otherDim.size().isPresent()) { + if ( ! thisDim.size().get().equals(otherDim.size().get())) + return Optional.empty(); + b.dimension(thisDim); // both are equal and bound + } + else if (thisDim.size().isPresent()) { + b.dimension(otherDim); // use the unbound + } + else if (otherDim.size().isPresent()) { + b.dimension(thisDim); // use the unbound + } + else { + b.dimension(thisDim); // both are equal and unbound + } + } + else if ( ! thisDim.isIndexed() && ! otherDim.isIndexed()) { + b.dimension(thisDim); // both are equal and mapped + } + else { + return Optional.empty(); // one indexed and one mapped + } + } + return Optional.of(b.build()); + } + @Override public int hashCode() { return dimensions.hashCode(); |