summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2020-01-10 13:07:54 +0100
committerLester Solbakken <lesters@oath.com>2020-01-10 13:07:54 +0100
commitf9e6262cd5f8b919db33f8a43cad45c25546d2e4 (patch)
treefdbd5e4a7b918f7195f4d01f11e694009760741c /vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
parent1bfeb920e039dd22f586c382c66fef90af6f4459 (diff)
Revert "Revert "Require equal sizes in join""
This reverts commit d78f8b089753025421524539e86ca96b7bf3369c.
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/TensorType.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java24
1 files changed, 16 insertions, 8 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index aeed8c33093..d8959147ee0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -323,7 +323,7 @@ public class TensorType {
* [N] + [] = []
* [] + {} = {}
*/
- Dimension combineWith(Optional<Dimension> other) {
+ Dimension combineWith(Optional<Dimension> other, boolean allowDifferentSizes) {
if ( ! other.isPresent()) return this;
if (this instanceof MappedDimension) return this;
if (other.get() instanceof MappedDimension) return other.get();
@@ -333,7 +333,11 @@ public class TensorType {
// both are indexed bound
IndexedBoundDimension thisIb = (IndexedBoundDimension)this;
IndexedBoundDimension otherIb = (IndexedBoundDimension)other.get();
- return thisIb.size().get() < otherIb.size().get() ? thisIb : otherIb;
+ if (allowDifferentSizes)
+ return thisIb.size().get() < otherIb.size().get() ? thisIb : otherIb;
+ if ( ! thisIb.size().equals(otherIb.size()))
+ throw new IllegalArgumentException("Unequal dimension sizes in " + thisIb + " and " + otherIb);
+ return thisIb;
}
@Override
@@ -494,9 +498,13 @@ public class TensorType {
* The value type will be the largest of the value types of the input types
*/
public Builder(TensorType ... types) {
+ this(true, types);
+ }
+
+ public Builder(boolean allowDifferentSizes, TensorType ... types) {
this.valueType = TensorType.combinedValueType(types);
for (TensorType type : types)
- addDimensionsOf(type);
+ addDimensionsOf(type, allowDifferentSizes);
}
/** Creates a builder from the given dimensions, having double as the value type */
@@ -514,17 +522,17 @@ public class TensorType {
private static final boolean supportsMixedTypes = false;
- private void addDimensionsOf(TensorType type) {
+ private void addDimensionsOf(TensorType type, boolean allowDifferentSizes) {
if ( ! supportsMixedTypes) { // TODO: Support it
- addDimensionsOfAndDisallowMixedDimensions(type);
+ addDimensionsOfAndDisallowMixedDimensions(type, allowDifferentSizes);
}
else {
for (Dimension dimension : type.dimensions)
- set(dimension.combineWith(Optional.ofNullable(dimensions.get(dimension.name()))));
+ set(dimension.combineWith(Optional.ofNullable(dimensions.get(dimension.name())), allowDifferentSizes));
}
}
- private void addDimensionsOfAndDisallowMixedDimensions(TensorType type) {
+ private void addDimensionsOfAndDisallowMixedDimensions(TensorType type, boolean allowDifferentSizes) {
boolean containsMapped = dimensions.values().stream().anyMatch(d -> ! d.isIndexed());
containsMapped = containsMapped || type.dimensions().stream().anyMatch(d -> ! d.isIndexed());
@@ -532,7 +540,7 @@ public class TensorType {
if (containsMapped)
dimension = new MappedDimension(dimension.name());
Dimension existing = dimensions.get(dimension.name());
- set(dimension.combineWith(Optional.ofNullable(existing)));
+ set(dimension.combineWith(Optional.ofNullable(existing), allowDifferentSizes));
}
}