summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2020-01-11 14:20:10 +0100
committerGitHub <noreply@github.com>2020-01-11 14:20:10 +0100
commit2d859a4a3f97e600655c699deccdf35d2a59be66 (patch)
tree5ea75beb84f6577afc4a8141c70dfdf049ec9f70 /vespajlib
parent6af036ff1be58aed8806610d5769952ac0192bdc (diff)
Revert "Revert "Revert "Require equal sizes in join"""
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/abi-spec.json1
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java24
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java11
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java6
4 files changed, 15 insertions, 27 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index d0c1abe061a..f631b3e1c58 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -1278,7 +1278,6 @@
"public void <init>()",
"public void <init>(com.yahoo.tensor.TensorType$Value)",
"public varargs void <init>(com.yahoo.tensor.TensorType[])",
- "public varargs void <init>(boolean, com.yahoo.tensor.TensorType[])",
"public void <init>(java.lang.Iterable)",
"public void <init>(com.yahoo.tensor.TensorType$Value, java.lang.Iterable)",
"public int rank()",
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index d8959147ee0..aeed8c33093 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -323,7 +323,7 @@ public class TensorType {
* [N] + [] = []
* [] + {} = {}
*/
- Dimension combineWith(Optional<Dimension> other, boolean allowDifferentSizes) {
+ Dimension combineWith(Optional<Dimension> other) {
if ( ! other.isPresent()) return this;
if (this instanceof MappedDimension) return this;
if (other.get() instanceof MappedDimension) return other.get();
@@ -333,11 +333,7 @@ public class TensorType {
// both are indexed bound
IndexedBoundDimension thisIb = (IndexedBoundDimension)this;
IndexedBoundDimension otherIb = (IndexedBoundDimension)other.get();
- if (allowDifferentSizes)
- return thisIb.size().get() < otherIb.size().get() ? thisIb : otherIb;
- if ( ! thisIb.size().equals(otherIb.size()))
- throw new IllegalArgumentException("Unequal dimension sizes in " + thisIb + " and " + otherIb);
- return thisIb;
+ return thisIb.size().get() < otherIb.size().get() ? thisIb : otherIb;
}
@Override
@@ -498,13 +494,9 @@ public class TensorType {
* The value type will be the largest of the value types of the input types
*/
public Builder(TensorType ... types) {
- this(true, types);
- }
-
- public Builder(boolean allowDifferentSizes, TensorType ... types) {
this.valueType = TensorType.combinedValueType(types);
for (TensorType type : types)
- addDimensionsOf(type, allowDifferentSizes);
+ addDimensionsOf(type);
}
/** Creates a builder from the given dimensions, having double as the value type */
@@ -522,17 +514,17 @@ public class TensorType {
private static final boolean supportsMixedTypes = false;
- private void addDimensionsOf(TensorType type, boolean allowDifferentSizes) {
+ private void addDimensionsOf(TensorType type) {
if ( ! supportsMixedTypes) { // TODO: Support it
- addDimensionsOfAndDisallowMixedDimensions(type, allowDifferentSizes);
+ addDimensionsOfAndDisallowMixedDimensions(type);
}
else {
for (Dimension dimension : type.dimensions)
- set(dimension.combineWith(Optional.ofNullable(dimensions.get(dimension.name())), allowDifferentSizes));
+ set(dimension.combineWith(Optional.ofNullable(dimensions.get(dimension.name()))));
}
}
- private void addDimensionsOfAndDisallowMixedDimensions(TensorType type, boolean allowDifferentSizes) {
+ private void addDimensionsOfAndDisallowMixedDimensions(TensorType type) {
boolean containsMapped = dimensions.values().stream().anyMatch(d -> ! d.isIndexed());
containsMapped = containsMapped || type.dimensions().stream().anyMatch(d -> ! d.isIndexed());
@@ -540,7 +532,7 @@ public class TensorType {
if (containsMapped)
dimension = new MappedDimension(dimension.name());
Dimension existing = dimensions.get(dimension.name());
- set(dimension.combineWith(Optional.ofNullable(existing), allowDifferentSizes));
+ set(dimension.combineWith(Optional.ofNullable(existing)));
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 5419d04a4fb..1e0eaa7fad3 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -48,12 +48,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
/** Returns the type resulting from applying Join to the two given types */
public static TensorType outputType(TensorType a, TensorType b) {
- try {
- return new TensorType.Builder(false, a, b).build();
- }
- catch (IllegalArgumentException e) {
- throw new IllegalArgumentException("Can not join " + a + " and " + b, e);
- }
+ return new TensorType.Builder(a, b).build();
}
public DoubleBinaryOperator combinator() { return combinator; }
@@ -80,14 +75,14 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
@Override
public TensorType type(TypeContext<NAMETYPE> context) {
- return outputType(argumentA.type(context), argumentB.type(context));
+ return new TensorType.Builder(argumentA.type(context), argumentB.type(context)).build();
}
@Override
public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
- TensorType joinedType = outputType(a.type(), b.type());
+ TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build();
return evaluate(a, b, joinedType, combinator);
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
index 5bd1bbdba37..b1851b5f120 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
@@ -130,11 +130,13 @@ public class TensorTestCase {
assertEquals("Mapped vector", 42, (int)dotProduct(vector(Type.mapped), vectors(Type.mapped, 2)));
assertEquals("Indexed unbound vector", 42, (int)dotProduct(vector(3, Type.indexedUnbound), vectors(5, Type.indexedUnbound, 2)));
assertEquals("Indexed unbound vector", 42, (int)dotProduct(vector(5, Type.indexedUnbound), vectors(3, Type.indexedUnbound, 2)));
- assertEquals("Indexed bound vector", 42, (int)dotProduct(vector(3, Type.indexedBound), vectors(3, Type.indexedBound, 2)));
+ assertEquals("Indexed bound vector", 42, (int)dotProduct(vector(3, Type.indexedBound), vectors(5, Type.indexedBound, 2)));
+ assertEquals("Indexed bound vector", 42, (int)dotProduct(vector(5, Type.indexedBound), vectors(3, Type.indexedBound, 2)));
assertEquals("Mapped matrix", 42, (int)dotProduct(vector(Type.mapped), matrix(Type.mapped, 2)));
assertEquals("Indexed unbound matrix", 42, (int)dotProduct(vector(3, Type.indexedUnbound), matrix(5, Type.indexedUnbound, 2)));
assertEquals("Indexed unbound matrix", 42, (int)dotProduct(vector(5, Type.indexedUnbound), matrix(3, Type.indexedUnbound, 2)));
- assertEquals("Indexed bound matrix", 42, (int)dotProduct(vector(3, Type.indexedBound), matrix(3, Type.indexedBound, 2)));
+ assertEquals("Indexed bound matrix", 42, (int)dotProduct(vector(3, Type.indexedBound), matrix(5, Type.indexedBound, 2)));
+ assertEquals("Indexed bound matrix", 42, (int)dotProduct(vector(5, Type.indexedBound), matrix(3, Type.indexedBound, 2)));
assertEquals("Mixed vector", 42, (int)dotProduct(vector(Type.mapped), vectors(Type.indexedUnbound, 2)));
assertEquals("Mixed vector", 42, (int)dotProduct(vector(Type.mapped), vectors(Type.indexedUnbound, 2)));
assertEquals("Mixed matrix", 42, (int)dotProduct(vector(Type.mapped), matrix(Type.indexedUnbound, 2)));