aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2020-01-03 13:02:20 +0100
committerJon Bratseth <bratseth@verizonmedia.com>2020-01-03 13:02:20 +0100
commite36af0b3a78fb8fc76c50eeb8392ef09e7c46ebb (patch)
tree009852649c9c4fc9a3c7c3c28a19873b4fa4977b /vespajlib/src/main/java/com/yahoo
parent798f2fd1d9c85febd9bb56ccb4866c37826c3b43 (diff)
Require equal sizes in join
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java26
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java11
2 files changed, 25 insertions, 12 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 58cb151875e..a499807105a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -319,7 +319,7 @@ public class TensorType {
* [N] + [] = []
* [] + {} = {}
*/
- Dimension combineWith(Optional<Dimension> other) {
+ Dimension combineWith(Optional<Dimension> other, boolean allowDifferentSizes) {
if ( ! other.isPresent()) return this;
if (this instanceof MappedDimension) return this;
if (other.get() instanceof MappedDimension) return other.get();
@@ -329,7 +329,11 @@ public class TensorType {
// both are indexed bound
IndexedBoundDimension thisIb = (IndexedBoundDimension)this;
IndexedBoundDimension otherIb = (IndexedBoundDimension)other.get();
- return thisIb.size().get() < otherIb.size().get() ? thisIb : otherIb;
+ if (allowDifferentSizes)
+ return thisIb.size().get() < otherIb.size().get() ? thisIb : otherIb;
+ if ( ! thisIb.size().equals(otherIb.size()))
+ throw new IllegalArgumentException("Unequal dimension sizes in " + thisIb + " and " + otherIb);
+ return thisIb;
}
@Override
@@ -483,16 +487,20 @@ public class TensorType {
/**
* Creates a builder containing a combination of the dimensions of the given types
*
- * If the same dimension is indexed with different size restrictions the largest size will be used.
+ * If the same dimension is indexed with different size restrictions the smallest size will be used.
* If it is size restricted in one argument but not the other it will not be size restricted.
* If it is indexed in one and mapped in the other it will become mapped.
*
* The value type will be the largest of the value types of the input types
*/
public Builder(TensorType ... types) {
+ this(true, types);
+ }
+
+ public Builder(boolean allowDifferentSizes, TensorType ... types) {
this.valueType = TensorType.combinedValueType(types);
for (TensorType type : types)
- addDimensionsOf(type);
+ addDimensionsOf(type, allowDifferentSizes);
}
/** Creates a builder from the given dimensions, having double as the value type */
@@ -510,17 +518,17 @@ public class TensorType {
private static final boolean supportsMixedTypes = false;
- private void addDimensionsOf(TensorType type) {
+ private void addDimensionsOf(TensorType type, boolean allowDifferentSizes) {
if ( ! supportsMixedTypes) { // TODO: Support it
- addDimensionsOfAndDisallowMixedDimensions(type);
+ addDimensionsOfAndDisallowMixedDimensions(type, allowDifferentSizes);
}
else {
for (Dimension dimension : type.dimensions)
- set(dimension.combineWith(Optional.ofNullable(dimensions.get(dimension.name()))));
+ set(dimension.combineWith(Optional.ofNullable(dimensions.get(dimension.name())), allowDifferentSizes));
}
}
- private void addDimensionsOfAndDisallowMixedDimensions(TensorType type) {
+ private void addDimensionsOfAndDisallowMixedDimensions(TensorType type, boolean allowDifferentSizes) {
boolean containsMapped = dimensions.values().stream().anyMatch(d -> ! d.isIndexed());
containsMapped = containsMapped || type.dimensions().stream().anyMatch(d -> ! d.isIndexed());
@@ -528,7 +536,7 @@ public class TensorType {
if (containsMapped)
dimension = new MappedDimension(dimension.name());
Dimension existing = dimensions.get(dimension.name());
- set(dimension.combineWith(Optional.ofNullable(existing)));
+ set(dimension.combineWith(Optional.ofNullable(existing), allowDifferentSizes));
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 1e0eaa7fad3..5419d04a4fb 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -48,7 +48,12 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
/** Returns the type resulting from applying Join to the two given types */
public static TensorType outputType(TensorType a, TensorType b) {
- return new TensorType.Builder(a, b).build();
+ try {
+ return new TensorType.Builder(false, a, b).build();
+ }
+ catch (IllegalArgumentException e) {
+ throw new IllegalArgumentException("Can not join " + a + " and " + b, e);
+ }
}
public DoubleBinaryOperator combinator() { return combinator; }
@@ -75,14 +80,14 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
@Override
public TensorType type(TypeContext<NAMETYPE> context) {
- return new TensorType.Builder(argumentA.type(context), argumentB.type(context)).build();
+ return outputType(argumentA.type(context), argumentB.type(context));
}
@Override
public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
- TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build();
+ TensorType joinedType = outputType(a.type(), b.type());
return evaluate(a, b, joinedType, combinator);
}