summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2020-01-13 21:01:53 +0100
committerGitHub <noreply@github.com>2020-01-13 21:01:53 +0100
commitb24667e4f883a1ad513fcaab3d47ef604611c7d4 (patch)
tree25e15f4df574ec07ca4babe26695657c55ee139f
parent7250e7a5cfafaa8e52a56c7990437be740761093 (diff)
parent6ba68c27681b36ef4c8fd1b3f5b7b03ec8459fc3 (diff)
Merge pull request #11769 from vespa-engine/revert-11750-revert-11745-bratseth/require-equal-sizes-in-join-2
Revert "Revert "Revert "Revert "Require equal sizes in join""""
-rw-r--r--config-model/src/test/derived/tensor/attributes.cfg2
-rw-r--r--config-model/src/test/derived/tensor/documenttypes.cfg2
-rw-r--r--config-model/src/test/derived/tensor/rank-profiles.cfg28
-rw-r--r--config-model/src/test/derived/tensor/tensor.sd4
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizerTestCase.java9
-rw-r--r--vespajlib/abi-spec.json1
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java24
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java11
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java6
9 files changed, 49 insertions, 38 deletions
diff --git a/config-model/src/test/derived/tensor/attributes.cfg b/config-model/src/test/derived/tensor/attributes.cfg
index 2e0a207d249..4634e120a3a 100644
--- a/config-model/src/test/derived/tensor/attributes.cfg
+++ b/config-model/src/test/derived/tensor/attributes.cfg
@@ -59,7 +59,7 @@ attribute[].arity 8
attribute[].lowerbound -9223372036854775808
attribute[].upperbound 9223372036854775807
attribute[].densepostinglistthreshold 0.4
-attribute[].tensortype "tensor(x[10],y[20])"
+attribute[].tensortype "tensor(x[10],y[10])"
attribute[].imported false
attribute[].name "f5"
attribute[].datatype TENSOR
diff --git a/config-model/src/test/derived/tensor/documenttypes.cfg b/config-model/src/test/derived/tensor/documenttypes.cfg
index 72fae572b76..68bd394c9d6 100644
--- a/config-model/src/test/derived/tensor/documenttypes.cfg
+++ b/config-model/src/test/derived/tensor/documenttypes.cfg
@@ -35,7 +35,7 @@ documenttype[].datatype[].sstruct.field[].detailedtype "tensor(x{})"
documenttype[].datatype[].sstruct.field[].name "f4"
documenttype[].datatype[].sstruct.field[].id 1224191509
documenttype[].datatype[].sstruct.field[].datatype 21
-documenttype[].datatype[].sstruct.field[].detailedtype "tensor(x[10],y[20])"
+documenttype[].datatype[].sstruct.field[].detailedtype "tensor(x[10],y[10])"
documenttype[].datatype[].sstruct.field[].name "f5"
documenttype[].datatype[].sstruct.field[].id 329055840
documenttype[].datatype[].sstruct.field[].datatype 21
diff --git a/config-model/src/test/derived/tensor/rank-profiles.cfg b/config-model/src/test/derived/tensor/rank-profiles.cfg
index 29dc39b01ce..7970e05b790 100644
--- a/config-model/src/test/derived/tensor/rank-profiles.cfg
+++ b/config-model/src/test/derived/tensor/rank-profiles.cfg
@@ -4,7 +4,7 @@ rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])"
rankprofile[].fef.property[].name "vespa.type.attribute.f3"
rankprofile[].fef.property[].value "tensor(x{})"
rankprofile[].fef.property[].name "vespa.type.attribute.f4"
-rankprofile[].fef.property[].value "tensor(x[10],y[20])"
+rankprofile[].fef.property[].value "tensor(x[10],y[10])"
rankprofile[].fef.property[].name "vespa.type.attribute.f5"
rankprofile[].fef.property[].value "tensor<float>(x[10])"
rankprofile[].name "unranked"
@@ -21,7 +21,7 @@ rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])"
rankprofile[].fef.property[].name "vespa.type.attribute.f3"
rankprofile[].fef.property[].value "tensor(x{})"
rankprofile[].fef.property[].name "vespa.type.attribute.f4"
-rankprofile[].fef.property[].value "tensor(x[10],y[20])"
+rankprofile[].fef.property[].value "tensor(x[10],y[10])"
rankprofile[].fef.property[].name "vespa.type.attribute.f5"
rankprofile[].fef.property[].value "tensor<float>(x[10])"
rankprofile[].name "profile1"
@@ -34,27 +34,27 @@ rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])"
rankprofile[].fef.property[].name "vespa.type.attribute.f3"
rankprofile[].fef.property[].value "tensor(x{})"
rankprofile[].fef.property[].name "vespa.type.attribute.f4"
-rankprofile[].fef.property[].value "tensor(x[10],y[20])"
+rankprofile[].fef.property[].value "tensor(x[10],y[10])"
rankprofile[].fef.property[].name "vespa.type.attribute.f5"
rankprofile[].fef.property[].value "tensor<float>(x[10])"
rankprofile[].name "profile2"
rankprofile[].fef.property[].name "vespa.rank.firstphase"
rankprofile[].fef.property[].value "rankingExpression(firstphase)"
rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript"
-rankprofile[].fef.property[].value "reduce(reduce(join(attribute(f4), tensor(x[2],y[2],z[3])((x==y)*(y==z)), f(a,b)(a * b)), sum, x), sum)"
+rankprofile[].fef.property[].value "reduce(reduce(join(attribute(f4), tensor(x[10],y[10],z[3])((x==y)*(y==z)), f(a,b)(a * b)), sum, x), sum)"
rankprofile[].fef.property[].name "vespa.type.attribute.f2"
rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])"
rankprofile[].fef.property[].name "vespa.type.attribute.f3"
rankprofile[].fef.property[].value "tensor(x{})"
rankprofile[].fef.property[].name "vespa.type.attribute.f4"
-rankprofile[].fef.property[].value "tensor(x[10],y[20])"
+rankprofile[].fef.property[].value "tensor(x[10],y[10])"
rankprofile[].fef.property[].name "vespa.type.attribute.f5"
rankprofile[].fef.property[].value "tensor<float>(x[10])"
rankprofile[].name "profile3"
rankprofile[].fef.property[].name "rankingExpression(joinedtensors).rankingScript"
rankprofile[].fef.property[].value "tensor(i[10])(i) * attribute(f4)"
rankprofile[].fef.property[].name "rankingExpression(joinedtensors).type"
-rankprofile[].fef.property[].value "tensor(i[10],x[10],y[20])"
+rankprofile[].fef.property[].value "tensor(i[10],x[10],y[10])"
rankprofile[].fef.property[].name "vespa.rank.firstphase"
rankprofile[].fef.property[].value "rankingExpression(firstphase)"
rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript"
@@ -64,7 +64,7 @@ rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])"
rankprofile[].fef.property[].name "vespa.type.attribute.f3"
rankprofile[].fef.property[].value "tensor(x{})"
rankprofile[].fef.property[].name "vespa.type.attribute.f4"
-rankprofile[].fef.property[].value "tensor(x[10],y[20])"
+rankprofile[].fef.property[].value "tensor(x[10],y[10])"
rankprofile[].fef.property[].name "vespa.type.attribute.f5"
rankprofile[].fef.property[].value "tensor<float>(x[10])"
rankprofile[].name "profile4"
@@ -77,7 +77,7 @@ rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])"
rankprofile[].fef.property[].name "vespa.type.attribute.f3"
rankprofile[].fef.property[].value "tensor(x{})"
rankprofile[].fef.property[].name "vespa.type.attribute.f4"
-rankprofile[].fef.property[].value "tensor(x[10],y[20])"
+rankprofile[].fef.property[].value "tensor(x[10],y[10])"
rankprofile[].fef.property[].name "vespa.type.attribute.f5"
rankprofile[].fef.property[].value "tensor<float>(x[10])"
rankprofile[].name "profile5"
@@ -90,14 +90,14 @@ rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])"
rankprofile[].fef.property[].name "vespa.type.attribute.f3"
rankprofile[].fef.property[].value "tensor(x{})"
rankprofile[].fef.property[].name "vespa.type.attribute.f4"
-rankprofile[].fef.property[].value "tensor(x[10],y[20])"
+rankprofile[].fef.property[].value "tensor(x[10],y[10])"
rankprofile[].fef.property[].name "vespa.type.attribute.f5"
rankprofile[].fef.property[].value "tensor<float>(x[10])"
rankprofile[].name "profile6"
rankprofile[].fef.property[].name "rankingExpression(joinedtensors).rankingScript"
rankprofile[].fef.property[].value "tensor(i[10])(i) * attribute(f4)"
rankprofile[].fef.property[].name "rankingExpression(joinedtensors).type"
-rankprofile[].fef.property[].value "tensor(i[10],x[10],y[20])"
+rankprofile[].fef.property[].value "tensor(i[10],x[10],y[10])"
rankprofile[].fef.property[].name "vespa.rank.firstphase"
rankprofile[].fef.property[].value "rankingExpression(firstphase)"
rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript"
@@ -107,7 +107,7 @@ rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])"
rankprofile[].fef.property[].name "vespa.type.attribute.f3"
rankprofile[].fef.property[].value "tensor(x{})"
rankprofile[].fef.property[].name "vespa.type.attribute.f4"
-rankprofile[].fef.property[].value "tensor(x[10],y[20])"
+rankprofile[].fef.property[].value "tensor(x[10],y[10])"
rankprofile[].fef.property[].name "vespa.type.attribute.f5"
rankprofile[].fef.property[].value "tensor<float>(x[10])"
rankprofile[].name "profile7"
@@ -128,7 +128,7 @@ rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])"
rankprofile[].fef.property[].name "vespa.type.attribute.f3"
rankprofile[].fef.property[].value "tensor(x{})"
rankprofile[].fef.property[].name "vespa.type.attribute.f4"
-rankprofile[].fef.property[].value "tensor(x[10],y[20])"
+rankprofile[].fef.property[].value "tensor(x[10],y[10])"
rankprofile[].fef.property[].name "vespa.type.attribute.f5"
rankprofile[].fef.property[].value "tensor<float>(x[10])"
rankprofile[].name "profile8"
@@ -143,7 +143,7 @@ rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])"
rankprofile[].fef.property[].name "vespa.type.attribute.f3"
rankprofile[].fef.property[].value "tensor(x{})"
rankprofile[].fef.property[].name "vespa.type.attribute.f4"
-rankprofile[].fef.property[].value "tensor(x[10],y[20])"
+rankprofile[].fef.property[].value "tensor(x[10],y[10])"
rankprofile[].fef.property[].name "vespa.type.attribute.f5"
rankprofile[].fef.property[].value "tensor<float>(x[10])"
rankprofile[].name "profile9"
@@ -158,7 +158,7 @@ rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])"
rankprofile[].fef.property[].name "vespa.type.attribute.f3"
rankprofile[].fef.property[].value "tensor(x{})"
rankprofile[].fef.property[].name "vespa.type.attribute.f4"
-rankprofile[].fef.property[].value "tensor(x[10],y[20])"
+rankprofile[].fef.property[].value "tensor(x[10],y[10])"
rankprofile[].fef.property[].name "vespa.type.attribute.f5"
rankprofile[].fef.property[].value "tensor<float>(x[10])"
diff --git a/config-model/src/test/derived/tensor/tensor.sd b/config-model/src/test/derived/tensor/tensor.sd
index aa33684a979..a7248fe3200 100644
--- a/config-model/src/test/derived/tensor/tensor.sd
+++ b/config-model/src/test/derived/tensor/tensor.sd
@@ -11,7 +11,7 @@ search tensor {
field f3 type tensor<double>(x{}) {
indexing: attribute | summary
}
- field f4 type tensor(x[10],y[20]) {
+ field f4 type tensor(x[10],y[10]) {
indexing: attribute | summary
}
field f5 type tensor<float>(x[10]) {
@@ -33,7 +33,7 @@ search tensor {
rank-profile profile2 {
first-phase {
- expression: sum(matmul(attribute(f4), diag(x[2],y[2],z[3]), x))
+ expression: sum(matmul(attribute(f4), diag(x[10],y[10],z[3]), x))
}
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizerTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizerTestCase.java
index 2cfc8cb575e..50f37486b90 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizerTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizerTestCase.java
@@ -26,7 +26,6 @@ public class TensorOptimizerTestCase {
assertWillOptimize("d0[3]", "d0[3]");
assertWillOptimize("d0[1]", "d0[1]", "d0");
assertWillOptimize("d0[2]", "d0[2]", "d0");
- assertWillOptimize("d0[1]", "d0[3]", "d0");
assertWillOptimize("d0[3]", "d0[3]", "d0");
assertWillOptimize("d0[3]", "d0[3],d1[2]", "d0");
assertWillOptimize("d0[3],d1[2]", "d0[3]", "d0");
@@ -34,10 +33,10 @@ public class TensorOptimizerTestCase {
assertWillOptimize("d0[2],d1[3]", "d1[3]", "d1");
assertWillOptimize("d0[2],d2[2]", "d1[3],d2[2]", "d2");
assertWillOptimize("d1[2],d2[2]", "d0[3],d2[2]", "d2");
- assertWillOptimize("d0[1],d2[2]", "d1[3],d2[4]", "d2");
- assertWillOptimize("d0[2],d2[2]", "d1[3],d2[4]", "d2");
- assertWillOptimize("d0[1],d1[2]", "d0[2],d1[3]");
- assertWillOptimize("d0[1],d1[2]", "d0[2],d1[3]", "d0,d1");
+ assertWillOptimize("d0[1],d2[4]", "d1[3],d2[4]", "d2");
+ assertWillOptimize("d0[2],d2[4]", "d1[3],d2[4]", "d2");
+ assertWillOptimize("d0[2],d1[3]", "d0[2],d1[3]");
+ assertWillOptimize("d0[2],d1[3]", "d0[2],d1[3]", "d0,d1");
assertWillOptimize("d2[3],d3[4]", "d1[2],d2[3],d3[4]", "d2,d3");
assertWillOptimize("d0[1],d2[3],d3[4]", "d1[2],d2[3],d3[4]", "d2,d3");
assertWillOptimize("d0[1],d1[2],d2[3]", "d2[3],d3[4],d4[5]", "d2");
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 66eb4b1f4e6..b8b6716d879 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -1279,6 +1279,7 @@
"public void <init>()",
"public void <init>(com.yahoo.tensor.TensorType$Value)",
"public varargs void <init>(com.yahoo.tensor.TensorType[])",
+ "public varargs void <init>(boolean, com.yahoo.tensor.TensorType[])",
"public void <init>(java.lang.Iterable)",
"public void <init>(com.yahoo.tensor.TensorType$Value, java.lang.Iterable)",
"public int rank()",
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index aeed8c33093..d8959147ee0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -323,7 +323,7 @@ public class TensorType {
* [N] + [] = []
* [] + {} = {}
*/
- Dimension combineWith(Optional<Dimension> other) {
+ Dimension combineWith(Optional<Dimension> other, boolean allowDifferentSizes) {
if ( ! other.isPresent()) return this;
if (this instanceof MappedDimension) return this;
if (other.get() instanceof MappedDimension) return other.get();
@@ -333,7 +333,11 @@ public class TensorType {
// both are indexed bound
IndexedBoundDimension thisIb = (IndexedBoundDimension)this;
IndexedBoundDimension otherIb = (IndexedBoundDimension)other.get();
- return thisIb.size().get() < otherIb.size().get() ? thisIb : otherIb;
+ if (allowDifferentSizes)
+ return thisIb.size().get() < otherIb.size().get() ? thisIb : otherIb;
+ if ( ! thisIb.size().equals(otherIb.size()))
+ throw new IllegalArgumentException("Unequal dimension sizes in " + thisIb + " and " + otherIb);
+ return thisIb;
}
@Override
@@ -494,9 +498,13 @@ public class TensorType {
* The value type will be the largest of the value types of the input types
*/
public Builder(TensorType ... types) {
+ this(true, types);
+ }
+
+ public Builder(boolean allowDifferentSizes, TensorType ... types) {
this.valueType = TensorType.combinedValueType(types);
for (TensorType type : types)
- addDimensionsOf(type);
+ addDimensionsOf(type, allowDifferentSizes);
}
/** Creates a builder from the given dimensions, having double as the value type */
@@ -514,17 +522,17 @@ public class TensorType {
private static final boolean supportsMixedTypes = false;
- private void addDimensionsOf(TensorType type) {
+ private void addDimensionsOf(TensorType type, boolean allowDifferentSizes) {
if ( ! supportsMixedTypes) { // TODO: Support it
- addDimensionsOfAndDisallowMixedDimensions(type);
+ addDimensionsOfAndDisallowMixedDimensions(type, allowDifferentSizes);
}
else {
for (Dimension dimension : type.dimensions)
- set(dimension.combineWith(Optional.ofNullable(dimensions.get(dimension.name()))));
+ set(dimension.combineWith(Optional.ofNullable(dimensions.get(dimension.name())), allowDifferentSizes));
}
}
- private void addDimensionsOfAndDisallowMixedDimensions(TensorType type) {
+ private void addDimensionsOfAndDisallowMixedDimensions(TensorType type, boolean allowDifferentSizes) {
boolean containsMapped = dimensions.values().stream().anyMatch(d -> ! d.isIndexed());
containsMapped = containsMapped || type.dimensions().stream().anyMatch(d -> ! d.isIndexed());
@@ -532,7 +540,7 @@ public class TensorType {
if (containsMapped)
dimension = new MappedDimension(dimension.name());
Dimension existing = dimensions.get(dimension.name());
- set(dimension.combineWith(Optional.ofNullable(existing)));
+ set(dimension.combineWith(Optional.ofNullable(existing), allowDifferentSizes));
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 1e0eaa7fad3..5419d04a4fb 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -48,7 +48,12 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
/** Returns the type resulting from applying Join to the two given types */
public static TensorType outputType(TensorType a, TensorType b) {
- return new TensorType.Builder(a, b).build();
+ try {
+ return new TensorType.Builder(false, a, b).build();
+ }
+ catch (IllegalArgumentException e) {
+ throw new IllegalArgumentException("Can not join " + a + " and " + b, e);
+ }
}
public DoubleBinaryOperator combinator() { return combinator; }
@@ -75,14 +80,14 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
@Override
public TensorType type(TypeContext<NAMETYPE> context) {
- return new TensorType.Builder(argumentA.type(context), argumentB.type(context)).build();
+ return outputType(argumentA.type(context), argumentB.type(context));
}
@Override
public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
- TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build();
+ TensorType joinedType = outputType(a.type(), b.type());
return evaluate(a, b, joinedType, combinator);
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
index b1851b5f120..5bd1bbdba37 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
@@ -130,13 +130,11 @@ public class TensorTestCase {
assertEquals("Mapped vector", 42, (int)dotProduct(vector(Type.mapped), vectors(Type.mapped, 2)));
assertEquals("Indexed unbound vector", 42, (int)dotProduct(vector(3, Type.indexedUnbound), vectors(5, Type.indexedUnbound, 2)));
assertEquals("Indexed unbound vector", 42, (int)dotProduct(vector(5, Type.indexedUnbound), vectors(3, Type.indexedUnbound, 2)));
- assertEquals("Indexed bound vector", 42, (int)dotProduct(vector(3, Type.indexedBound), vectors(5, Type.indexedBound, 2)));
- assertEquals("Indexed bound vector", 42, (int)dotProduct(vector(5, Type.indexedBound), vectors(3, Type.indexedBound, 2)));
+ assertEquals("Indexed bound vector", 42, (int)dotProduct(vector(3, Type.indexedBound), vectors(3, Type.indexedBound, 2)));
assertEquals("Mapped matrix", 42, (int)dotProduct(vector(Type.mapped), matrix(Type.mapped, 2)));
assertEquals("Indexed unbound matrix", 42, (int)dotProduct(vector(3, Type.indexedUnbound), matrix(5, Type.indexedUnbound, 2)));
assertEquals("Indexed unbound matrix", 42, (int)dotProduct(vector(5, Type.indexedUnbound), matrix(3, Type.indexedUnbound, 2)));
- assertEquals("Indexed bound matrix", 42, (int)dotProduct(vector(3, Type.indexedBound), matrix(5, Type.indexedBound, 2)));
- assertEquals("Indexed bound matrix", 42, (int)dotProduct(vector(5, Type.indexedBound), matrix(3, Type.indexedBound, 2)));
+ assertEquals("Indexed bound matrix", 42, (int)dotProduct(vector(3, Type.indexedBound), matrix(3, Type.indexedBound, 2)));
assertEquals("Mixed vector", 42, (int)dotProduct(vector(Type.mapped), vectors(Type.indexedUnbound, 2)));
assertEquals("Mixed vector", 42, (int)dotProduct(vector(Type.mapped), vectors(Type.indexedUnbound, 2)));
assertEquals("Mixed matrix", 42, (int)dotProduct(vector(Type.mapped), matrix(Type.indexedUnbound, 2)));