diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2020-01-03 13:02:20 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2020-01-03 13:02:20 +0100 |
commit | e36af0b3a78fb8fc76c50eeb8392ef09e7c46ebb (patch) | |
tree | 009852649c9c4fc9a3c7c3c28a19873b4fa4977b | |
parent | 798f2fd1d9c85febd9bb56ccb4866c37826c3b43 (diff) |
Require equal sizes in join
9 files changed, 50 insertions, 39 deletions
diff --git a/config-model/src/test/derived/tensor/attributes.cfg b/config-model/src/test/derived/tensor/attributes.cfg index 2e0a207d249..4634e120a3a 100644 --- a/config-model/src/test/derived/tensor/attributes.cfg +++ b/config-model/src/test/derived/tensor/attributes.cfg @@ -59,7 +59,7 @@ attribute[].arity 8 attribute[].lowerbound -9223372036854775808 attribute[].upperbound 9223372036854775807 attribute[].densepostinglistthreshold 0.4 -attribute[].tensortype "tensor(x[10],y[20])" +attribute[].tensortype "tensor(x[10],y[10])" attribute[].imported false attribute[].name "f5" attribute[].datatype TENSOR diff --git a/config-model/src/test/derived/tensor/documenttypes.cfg b/config-model/src/test/derived/tensor/documenttypes.cfg index 72fae572b76..68bd394c9d6 100644 --- a/config-model/src/test/derived/tensor/documenttypes.cfg +++ b/config-model/src/test/derived/tensor/documenttypes.cfg @@ -35,7 +35,7 @@ documenttype[].datatype[].sstruct.field[].detailedtype "tensor(x{})" documenttype[].datatype[].sstruct.field[].name "f4" documenttype[].datatype[].sstruct.field[].id 1224191509 documenttype[].datatype[].sstruct.field[].datatype 21 -documenttype[].datatype[].sstruct.field[].detailedtype "tensor(x[10],y[20])" +documenttype[].datatype[].sstruct.field[].detailedtype "tensor(x[10],y[10])" documenttype[].datatype[].sstruct.field[].name "f5" documenttype[].datatype[].sstruct.field[].id 329055840 documenttype[].datatype[].sstruct.field[].datatype 21 diff --git a/config-model/src/test/derived/tensor/rank-profiles.cfg b/config-model/src/test/derived/tensor/rank-profiles.cfg index 29dc39b01ce..7970e05b790 100644 --- a/config-model/src/test/derived/tensor/rank-profiles.cfg +++ b/config-model/src/test/derived/tensor/rank-profiles.cfg @@ -4,7 +4,7 @@ rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])" rankprofile[].fef.property[].name "vespa.type.attribute.f3" rankprofile[].fef.property[].value "tensor(x{})" rankprofile[].fef.property[].name "vespa.type.attribute.f4" -rankprofile[].fef.property[].value "tensor(x[10],y[20])" +rankprofile[].fef.property[].value "tensor(x[10],y[10])" rankprofile[].fef.property[].name "vespa.type.attribute.f5" rankprofile[].fef.property[].value "tensor<float>(x[10])" rankprofile[].name "unranked" @@ -21,7 +21,7 @@ rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])" rankprofile[].fef.property[].name "vespa.type.attribute.f3" rankprofile[].fef.property[].value "tensor(x{})" rankprofile[].fef.property[].name "vespa.type.attribute.f4" -rankprofile[].fef.property[].value "tensor(x[10],y[20])" +rankprofile[].fef.property[].value "tensor(x[10],y[10])" rankprofile[].fef.property[].name "vespa.type.attribute.f5" rankprofile[].fef.property[].value "tensor<float>(x[10])" rankprofile[].name "profile1" @@ -34,27 +34,27 @@ rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])" rankprofile[].fef.property[].name "vespa.type.attribute.f3" rankprofile[].fef.property[].value "tensor(x{})" rankprofile[].fef.property[].name "vespa.type.attribute.f4" -rankprofile[].fef.property[].value "tensor(x[10],y[20])" +rankprofile[].fef.property[].value "tensor(x[10],y[10])" rankprofile[].fef.property[].name "vespa.type.attribute.f5" rankprofile[].fef.property[].value "tensor<float>(x[10])" rankprofile[].name "profile2" rankprofile[].fef.property[].name "vespa.rank.firstphase" rankprofile[].fef.property[].value "rankingExpression(firstphase)" rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript" -rankprofile[].fef.property[].value "reduce(reduce(join(attribute(f4), tensor(x[2],y[2],z[3])((x==y)*(y==z)), f(a,b)(a * b)), sum, x), sum)" +rankprofile[].fef.property[].value "reduce(reduce(join(attribute(f4), tensor(x[10],y[10],z[3])((x==y)*(y==z)), f(a,b)(a * b)), sum, x), sum)" rankprofile[].fef.property[].name "vespa.type.attribute.f2" rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])" rankprofile[].fef.property[].name "vespa.type.attribute.f3" rankprofile[].fef.property[].value "tensor(x{})" rankprofile[].fef.property[].name "vespa.type.attribute.f4" -rankprofile[].fef.property[].value "tensor(x[10],y[20])" +rankprofile[].fef.property[].value "tensor(x[10],y[10])" rankprofile[].fef.property[].name "vespa.type.attribute.f5" rankprofile[].fef.property[].value "tensor<float>(x[10])" rankprofile[].name "profile3" rankprofile[].fef.property[].name "rankingExpression(joinedtensors).rankingScript" rankprofile[].fef.property[].value "tensor(i[10])(i) * attribute(f4)" rankprofile[].fef.property[].name "rankingExpression(joinedtensors).type" -rankprofile[].fef.property[].value "tensor(i[10],x[10],y[20])" +rankprofile[].fef.property[].value "tensor(i[10],x[10],y[10])" rankprofile[].fef.property[].name "vespa.rank.firstphase" rankprofile[].fef.property[].value "rankingExpression(firstphase)" rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript" @@ -64,7 +64,7 @@ rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])" rankprofile[].fef.property[].name "vespa.type.attribute.f3" rankprofile[].fef.property[].value "tensor(x{})" rankprofile[].fef.property[].name "vespa.type.attribute.f4" -rankprofile[].fef.property[].value "tensor(x[10],y[20])" +rankprofile[].fef.property[].value "tensor(x[10],y[10])" rankprofile[].fef.property[].name "vespa.type.attribute.f5" rankprofile[].fef.property[].value "tensor<float>(x[10])" rankprofile[].name "profile4" @@ -77,7 +77,7 @@ rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])" rankprofile[].fef.property[].name "vespa.type.attribute.f3" rankprofile[].fef.property[].value "tensor(x{})" rankprofile[].fef.property[].name "vespa.type.attribute.f4" -rankprofile[].fef.property[].value "tensor(x[10],y[20])" +rankprofile[].fef.property[].value "tensor(x[10],y[10])" rankprofile[].fef.property[].name "vespa.type.attribute.f5" rankprofile[].fef.property[].value "tensor<float>(x[10])" rankprofile[].name "profile5" @@ -90,14 +90,14 @@ rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])" rankprofile[].fef.property[].name "vespa.type.attribute.f3" rankprofile[].fef.property[].value "tensor(x{})" rankprofile[].fef.property[].name "vespa.type.attribute.f4" -rankprofile[].fef.property[].value "tensor(x[10],y[20])" +rankprofile[].fef.property[].value "tensor(x[10],y[10])" rankprofile[].fef.property[].name "vespa.type.attribute.f5" rankprofile[].fef.property[].value "tensor<float>(x[10])" rankprofile[].name "profile6" rankprofile[].fef.property[].name "rankingExpression(joinedtensors).rankingScript" rankprofile[].fef.property[].value "tensor(i[10])(i) * attribute(f4)" rankprofile[].fef.property[].name "rankingExpression(joinedtensors).type" -rankprofile[].fef.property[].value "tensor(i[10],x[10],y[20])" +rankprofile[].fef.property[].value "tensor(i[10],x[10],y[10])" rankprofile[].fef.property[].name "vespa.rank.firstphase" rankprofile[].fef.property[].value "rankingExpression(firstphase)" rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript" @@ -107,7 +107,7 @@ rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])" rankprofile[].fef.property[].name "vespa.type.attribute.f3" rankprofile[].fef.property[].value "tensor(x{})" rankprofile[].fef.property[].name "vespa.type.attribute.f4" -rankprofile[].fef.property[].value "tensor(x[10],y[20])" +rankprofile[].fef.property[].value "tensor(x[10],y[10])" rankprofile[].fef.property[].name "vespa.type.attribute.f5" rankprofile[].fef.property[].value "tensor<float>(x[10])" rankprofile[].name "profile7" @@ -128,7 +128,7 @@ rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])" rankprofile[].fef.property[].name "vespa.type.attribute.f3" rankprofile[].fef.property[].value "tensor(x{})" rankprofile[].fef.property[].name "vespa.type.attribute.f4" -rankprofile[].fef.property[].value "tensor(x[10],y[20])" +rankprofile[].fef.property[].value "tensor(x[10],y[10])" rankprofile[].fef.property[].name "vespa.type.attribute.f5" rankprofile[].fef.property[].value "tensor<float>(x[10])" rankprofile[].name "profile8" @@ -143,7 +143,7 @@ rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])" rankprofile[].fef.property[].name "vespa.type.attribute.f3" rankprofile[].fef.property[].value "tensor(x{})" rankprofile[].fef.property[].name "vespa.type.attribute.f4" -rankprofile[].fef.property[].value "tensor(x[10],y[20])" +rankprofile[].fef.property[].value "tensor(x[10],y[10])" rankprofile[].fef.property[].name "vespa.type.attribute.f5" rankprofile[].fef.property[].value "tensor<float>(x[10])" rankprofile[].name "profile9" @@ -158,7 +158,7 @@ rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])" rankprofile[].fef.property[].name "vespa.type.attribute.f3" rankprofile[].fef.property[].value "tensor(x{})" rankprofile[].fef.property[].name "vespa.type.attribute.f4" -rankprofile[].fef.property[].value "tensor(x[10],y[20])" +rankprofile[].fef.property[].value "tensor(x[10],y[10])" rankprofile[].fef.property[].name "vespa.type.attribute.f5" rankprofile[].fef.property[].value "tensor<float>(x[10])" diff --git a/config-model/src/test/derived/tensor/tensor.sd b/config-model/src/test/derived/tensor/tensor.sd index aa33684a979..a7248fe3200 100644 --- a/config-model/src/test/derived/tensor/tensor.sd +++ b/config-model/src/test/derived/tensor/tensor.sd @@ -11,7 +11,7 @@ search tensor { field f3 type tensor<double>(x{}) { indexing: attribute | summary } - field f4 type tensor(x[10],y[20]) { + field f4 type tensor(x[10],y[10]) { indexing: attribute | summary } field f5 type tensor<float>(x[10]) { @@ -33,7 +33,7 @@ search tensor { rank-profile profile2 { first-phase { - expression: sum(matmul(attribute(f4), diag(x[2],y[2],z[3]), x)) + expression: sum(matmul(attribute(f4), diag(x[10],y[10],z[3]), x)) } } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizerTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizerTestCase.java index 2cfc8cb575e..50f37486b90 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizerTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizerTestCase.java @@ -26,7 +26,6 @@ public class TensorOptimizerTestCase { assertWillOptimize("d0[3]", "d0[3]"); assertWillOptimize("d0[1]", "d0[1]", "d0"); assertWillOptimize("d0[2]", "d0[2]", "d0"); - assertWillOptimize("d0[1]", "d0[3]", "d0"); assertWillOptimize("d0[3]", "d0[3]", "d0"); assertWillOptimize("d0[3]", "d0[3],d1[2]", "d0"); assertWillOptimize("d0[3],d1[2]", "d0[3]", "d0"); @@ -34,10 +33,10 @@ public class TensorOptimizerTestCase { assertWillOptimize("d0[2],d1[3]", "d1[3]", "d1"); assertWillOptimize("d0[2],d2[2]", "d1[3],d2[2]", "d2"); assertWillOptimize("d1[2],d2[2]", "d0[3],d2[2]", "d2"); - assertWillOptimize("d0[1],d2[2]", "d1[3],d2[4]", "d2"); - assertWillOptimize("d0[2],d2[2]", "d1[3],d2[4]", "d2"); - assertWillOptimize("d0[1],d1[2]", "d0[2],d1[3]"); - assertWillOptimize("d0[1],d1[2]", "d0[2],d1[3]", "d0,d1"); + assertWillOptimize("d0[1],d2[4]", "d1[3],d2[4]", "d2"); + assertWillOptimize("d0[2],d2[4]", "d1[3],d2[4]", "d2"); + assertWillOptimize("d0[2],d1[3]", "d0[2],d1[3]"); + assertWillOptimize("d0[2],d1[3]", "d0[2],d1[3]", "d0,d1"); assertWillOptimize("d2[3],d3[4]", "d1[2],d2[3],d3[4]", "d2,d3"); assertWillOptimize("d0[1],d2[3],d3[4]", "d1[2],d2[3],d3[4]", "d2,d3"); assertWillOptimize("d0[1],d1[2],d2[3]", "d2[3],d3[4],d4[5]", "d2"); diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index a4a9a1e1b24..c96a490a9d6 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1280,6 +1280,7 @@ "public void <init>()", "public void <init>(com.yahoo.tensor.TensorType$Value)", "public varargs void <init>(com.yahoo.tensor.TensorType[])", + "public varargs void <init>(boolean, com.yahoo.tensor.TensorType[])", "public void <init>(java.lang.Iterable)", "public void <init>(com.yahoo.tensor.TensorType$Value, java.lang.Iterable)", "public int rank()", diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 58cb151875e..a499807105a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -319,7 +319,7 @@ public class TensorType { * [N] + [] = [] * [] + {} = {} */ - Dimension combineWith(Optional<Dimension> other) { + Dimension combineWith(Optional<Dimension> other, boolean allowDifferentSizes) { if ( ! other.isPresent()) return this; if (this instanceof MappedDimension) return this; if (other.get() instanceof MappedDimension) return other.get(); @@ -329,7 +329,11 @@ public class TensorType { // both are indexed bound IndexedBoundDimension thisIb = (IndexedBoundDimension)this; IndexedBoundDimension otherIb = (IndexedBoundDimension)other.get(); - return thisIb.size().get() < otherIb.size().get() ? thisIb : otherIb; + if (allowDifferentSizes) + return thisIb.size().get() < otherIb.size().get() ? thisIb : otherIb; + if ( ! thisIb.size().equals(otherIb.size())) + throw new IllegalArgumentException("Unequal dimension sizes in " + thisIb + " and " + otherIb); + return thisIb; } @Override @@ -483,16 +487,20 @@ public class TensorType { /** * Creates a builder containing a combination of the dimensions of the given types * - * If the same dimension is indexed with different size restrictions the largest size will be used. + * If the same dimension is indexed with different size restrictions the smallest size will be used. * If it is size restricted in one argument but not the other it will not be size restricted. * If it is indexed in one and mapped in the other it will become mapped. * * The value type will be the largest of the value types of the input types */ public Builder(TensorType ... types) { + this(true, types); + } + + public Builder(boolean allowDifferentSizes, TensorType ... types) { this.valueType = TensorType.combinedValueType(types); for (TensorType type : types) - addDimensionsOf(type); + addDimensionsOf(type, allowDifferentSizes); } /** Creates a builder from the given dimensions, having double as the value type */ @@ -510,17 +518,17 @@ public class TensorType { private static final boolean supportsMixedTypes = false; - private void addDimensionsOf(TensorType type) { + private void addDimensionsOf(TensorType type, boolean allowDifferentSizes) { if ( ! supportsMixedTypes) { // TODO: Support it - addDimensionsOfAndDisallowMixedDimensions(type); + addDimensionsOfAndDisallowMixedDimensions(type, allowDifferentSizes); } else { for (Dimension dimension : type.dimensions) - set(dimension.combineWith(Optional.ofNullable(dimensions.get(dimension.name())))); + set(dimension.combineWith(Optional.ofNullable(dimensions.get(dimension.name())), allowDifferentSizes)); } } - private void addDimensionsOfAndDisallowMixedDimensions(TensorType type) { + private void addDimensionsOfAndDisallowMixedDimensions(TensorType type, boolean allowDifferentSizes) { boolean containsMapped = dimensions.values().stream().anyMatch(d -> ! d.isIndexed()); containsMapped = containsMapped || type.dimensions().stream().anyMatch(d -> ! d.isIndexed()); @@ -528,7 +536,7 @@ public class TensorType { if (containsMapped) dimension = new MappedDimension(dimension.name()); Dimension existing = dimensions.get(dimension.name()); - set(dimension.combineWith(Optional.ofNullable(existing))); + set(dimension.combineWith(Optional.ofNullable(existing), allowDifferentSizes)); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 1e0eaa7fad3..5419d04a4fb 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -48,7 +48,12 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP /** Returns the type resulting from applying Join to the two given types */ public static TensorType outputType(TensorType a, TensorType b) { - return new TensorType.Builder(a, b).build(); + try { + return new TensorType.Builder(false, a, b).build(); + } + catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Can not join " + a + " and " + b, e); + } } public DoubleBinaryOperator combinator() { return combinator; } @@ -75,14 +80,14 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP @Override public TensorType type(TypeContext<NAMETYPE> context) { - return new TensorType.Builder(argumentA.type(context), argumentB.type(context)).build(); + return outputType(argumentA.type(context), argumentB.type(context)); } @Override public Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); - TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build(); + TensorType joinedType = outputType(a.type(), b.type()); return evaluate(a, b, joinedType, combinator); } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java index 7932f90d797..13b9e9e762f 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -130,13 +130,11 @@ public class TensorTestCase { assertEquals("Mapped vector", 42, (int)dotProduct(vector(Type.mapped), vectors(Type.mapped, 2))); assertEquals("Indexed unbound vector", 42, (int)dotProduct(vector(3, Type.indexedUnbound), vectors(5, Type.indexedUnbound, 2))); assertEquals("Indexed unbound vector", 42, (int)dotProduct(vector(5, Type.indexedUnbound), vectors(3, Type.indexedUnbound, 2))); - assertEquals("Indexed bound vector", 42, (int)dotProduct(vector(3, Type.indexedBound), vectors(5, Type.indexedBound, 2))); - assertEquals("Indexed bound vector", 42, (int)dotProduct(vector(5, Type.indexedBound), vectors(3, Type.indexedBound, 2))); + assertEquals("Indexed bound vector", 42, (int)dotProduct(vector(3, Type.indexedBound), vectors(3, Type.indexedBound, 2))); assertEquals("Mapped matrix", 42, (int)dotProduct(vector(Type.mapped), matrix(Type.mapped, 2))); assertEquals("Indexed unbound matrix", 42, (int)dotProduct(vector(3, Type.indexedUnbound), matrix(5, Type.indexedUnbound, 2))); assertEquals("Indexed unbound matrix", 42, (int)dotProduct(vector(5, Type.indexedUnbound), matrix(3, Type.indexedUnbound, 2))); - assertEquals("Indexed bound matrix", 42, (int)dotProduct(vector(3, Type.indexedBound), matrix(5, Type.indexedBound, 2))); - assertEquals("Indexed bound matrix", 42, (int)dotProduct(vector(5, Type.indexedBound), matrix(3, Type.indexedBound, 2))); + assertEquals("Indexed bound matrix", 42, (int)dotProduct(vector(3, Type.indexedBound), matrix(3, Type.indexedBound, 2))); assertEquals("Mixed vector", 42, (int)dotProduct(vector(Type.mapped), vectors(Type.indexedUnbound, 2))); assertEquals("Mixed vector", 42, (int)dotProduct(vector(Type.mapped), vectors(Type.indexedUnbound, 2))); assertEquals("Mixed matrix", 42, (int)dotProduct(vector(Type.mapped), matrix(Type.indexedUnbound, 2))); |