diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2020-01-11 14:20:10 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-01-11 14:20:10 +0100 |
commit | 2d859a4a3f97e600655c699deccdf35d2a59be66 (patch) | |
tree | 5ea75beb84f6577afc4a8141c70dfdf049ec9f70 /vespajlib | |
parent | 6af036ff1be58aed8806610d5769952ac0192bdc (diff) |
Revert "Revert "Revert "Require equal sizes in join"""
Diffstat (limited to 'vespajlib')
4 files changed, 15 insertions, 27 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index d0c1abe061a..f631b3e1c58 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1278,7 +1278,6 @@ "public void <init>()", "public void <init>(com.yahoo.tensor.TensorType$Value)", "public varargs void <init>(com.yahoo.tensor.TensorType[])", - "public varargs void <init>(boolean, com.yahoo.tensor.TensorType[])", "public void <init>(java.lang.Iterable)", "public void <init>(com.yahoo.tensor.TensorType$Value, java.lang.Iterable)", "public int rank()", diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index d8959147ee0..aeed8c33093 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -323,7 +323,7 @@ public class TensorType { * [N] + [] = [] * [] + {} = {} */ - Dimension combineWith(Optional<Dimension> other, boolean allowDifferentSizes) { + Dimension combineWith(Optional<Dimension> other) { if ( ! other.isPresent()) return this; if (this instanceof MappedDimension) return this; if (other.get() instanceof MappedDimension) return other.get(); @@ -333,11 +333,7 @@ public class TensorType { // both are indexed bound IndexedBoundDimension thisIb = (IndexedBoundDimension)this; IndexedBoundDimension otherIb = (IndexedBoundDimension)other.get(); - if (allowDifferentSizes) - return thisIb.size().get() < otherIb.size().get() ? thisIb : otherIb; - if ( ! thisIb.size().equals(otherIb.size())) - throw new IllegalArgumentException("Unequal dimension sizes in " + thisIb + " and " + otherIb); - return thisIb; + return thisIb.size().get() < otherIb.size().get() ? thisIb : otherIb; } @Override @@ -498,13 +494,9 @@ public class TensorType { * The value type will be the largest of the value types of the input types */ public Builder(TensorType ... types) { - this(true, types); - } - - public Builder(boolean allowDifferentSizes, TensorType ... types) { this.valueType = TensorType.combinedValueType(types); for (TensorType type : types) - addDimensionsOf(type, allowDifferentSizes); + addDimensionsOf(type); } /** Creates a builder from the given dimensions, having double as the value type */ @@ -522,17 +514,17 @@ public class TensorType { private static final boolean supportsMixedTypes = false; - private void addDimensionsOf(TensorType type, boolean allowDifferentSizes) { + private void addDimensionsOf(TensorType type) { if ( ! supportsMixedTypes) { // TODO: Support it - addDimensionsOfAndDisallowMixedDimensions(type, allowDifferentSizes); + addDimensionsOfAndDisallowMixedDimensions(type); } else { for (Dimension dimension : type.dimensions) - set(dimension.combineWith(Optional.ofNullable(dimensions.get(dimension.name())), allowDifferentSizes)); + set(dimension.combineWith(Optional.ofNullable(dimensions.get(dimension.name())))); } } - private void addDimensionsOfAndDisallowMixedDimensions(TensorType type, boolean allowDifferentSizes) { + private void addDimensionsOfAndDisallowMixedDimensions(TensorType type) { boolean containsMapped = dimensions.values().stream().anyMatch(d -> ! d.isIndexed()); containsMapped = containsMapped || type.dimensions().stream().anyMatch(d -> ! d.isIndexed()); @@ -540,7 +532,7 @@ public class TensorType { if (containsMapped) dimension = new MappedDimension(dimension.name()); Dimension existing = dimensions.get(dimension.name()); - set(dimension.combineWith(Optional.ofNullable(existing), allowDifferentSizes)); + set(dimension.combineWith(Optional.ofNullable(existing))); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 5419d04a4fb..1e0eaa7fad3 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -48,12 +48,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP /** Returns the type resulting from applying Join to the two given types */ public static TensorType outputType(TensorType a, TensorType b) { - try { - return new TensorType.Builder(false, a, b).build(); - } - catch (IllegalArgumentException e) { - throw new IllegalArgumentException("Can not join " + a + " and " + b, e); - } + return new TensorType.Builder(a, b).build(); } public DoubleBinaryOperator combinator() { return combinator; } @@ -80,14 +75,14 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP @Override public TensorType type(TypeContext<NAMETYPE> context) { - return outputType(argumentA.type(context), argumentB.type(context)); + return new TensorType.Builder(argumentA.type(context), argumentB.type(context)).build(); } @Override public Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); - TensorType joinedType = outputType(a.type(), b.type()); + TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build(); return evaluate(a, b, joinedType, combinator); } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java index 5bd1bbdba37..b1851b5f120 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -130,11 +130,13 @@ public class TensorTestCase { assertEquals("Mapped vector", 42, (int)dotProduct(vector(Type.mapped), vectors(Type.mapped, 2))); assertEquals("Indexed unbound vector", 42, (int)dotProduct(vector(3, Type.indexedUnbound), vectors(5, Type.indexedUnbound, 2))); assertEquals("Indexed unbound vector", 42, (int)dotProduct(vector(5, Type.indexedUnbound), vectors(3, Type.indexedUnbound, 2))); - assertEquals("Indexed bound vector", 42, (int)dotProduct(vector(3, Type.indexedBound), vectors(3, Type.indexedBound, 2))); + assertEquals("Indexed bound vector", 42, (int)dotProduct(vector(3, Type.indexedBound), vectors(5, Type.indexedBound, 2))); + assertEquals("Indexed bound vector", 42, (int)dotProduct(vector(5, Type.indexedBound), vectors(3, Type.indexedBound, 2))); assertEquals("Mapped matrix", 42, (int)dotProduct(vector(Type.mapped), matrix(Type.mapped, 2))); assertEquals("Indexed unbound matrix", 42, (int)dotProduct(vector(3, Type.indexedUnbound), matrix(5, Type.indexedUnbound, 2))); assertEquals("Indexed unbound matrix", 42, (int)dotProduct(vector(5, Type.indexedUnbound), matrix(3, Type.indexedUnbound, 2))); - assertEquals("Indexed bound matrix", 42, (int)dotProduct(vector(3, Type.indexedBound), matrix(3, Type.indexedBound, 2))); + assertEquals("Indexed bound matrix", 42, (int)dotProduct(vector(3, Type.indexedBound), matrix(5, Type.indexedBound, 2))); + assertEquals("Indexed bound matrix", 42, (int)dotProduct(vector(5, Type.indexedBound), matrix(3, Type.indexedBound, 2))); assertEquals("Mixed vector", 42, (int)dotProduct(vector(Type.mapped), vectors(Type.indexedUnbound, 2))); assertEquals("Mixed vector", 42, (int)dotProduct(vector(Type.mapped), vectors(Type.indexedUnbound, 2))); assertEquals("Mixed matrix", 42, (int)dotProduct(vector(Type.mapped), matrix(Type.indexedUnbound, 2))); |