summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-02-01 11:44:31 +0100
committerJon Bratseth <bratseth@oath.com>2018-02-01 11:44:31 +0100
commit99ef288b5023f6879a944eaf2ba325de8997aa50 (patch)
treedc46760052327a41d5a585008aa2a67df670a75b /vespajlib
parentc9044baf967cb8aac50ba63519b9f5b9097d9d8e (diff)
Allow compatible changes to stored tensors
Allow increasing the size of tensor dimensions without making stored data incompatible.
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java22
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java21
2 files changed, 42 insertions, 1 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 7eae9539e77..8ff9774fc7d 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -82,6 +82,21 @@ public class TensorType {
* i.e if the given type is a generalization of this type.
*/
public boolean isAssignableTo(TensorType generalization) {
+ return isConvertibleOrAssignableTo(generalization, false);
+ }
+
+ /**
+ * Returns whether this type can be converted to the given type.
+ * This is true if this type isAssignableTo the given type or
+ * if it is not assignable only because it has a shorter dimension length
+ * than the given type in some shared dimension(s), as it can then be
+ * converted to the given type by zero padding.
+ */
+ public boolean isConvertibleTo(TensorType generalization) {
+ return isConvertibleOrAssignableTo(generalization, true);
+ }
+
+ private boolean isConvertibleOrAssignableTo(TensorType generalization, boolean convertible) {
if (generalization.dimensions().size() != this.dimensions().size()) return false;
for (int i = 0; i < generalization.dimensions().size(); i++) {
Dimension thisDimension = this.dimensions().get(i);
@@ -90,7 +105,12 @@ public class TensorType {
if ( ! thisDimension.name().equals(generalizationDimension.name())) return false;
if (generalizationDimension.size().isPresent()) {
if ( ! thisDimension.size().isPresent()) return false;
- if ( ! thisDimension.size().get().equals(generalizationDimension.size().get()) ) return false;
+ if (convertible) {
+ if (thisDimension.size().get() > generalizationDimension.size().get()) return false;
+ }
+ else { // assignable
+ if (!thisDimension.size().get().equals(generalizationDimension.size().get())) return false;
+ }
}
}
return true;
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
index c3e31fad2da..eef0b090fd1 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
@@ -12,6 +12,7 @@ import static org.junit.Assert.fail;
/**
* @author geirst
+ * @author bratseth
*/
public class TensorTypeTestCase {
@@ -75,6 +76,18 @@ public class TensorTypeTestCase {
assertIsAssignableTo("tensor(x{},y[10])", "tensor(x{},y[])");
}
+ @Test
+ public void testConvertibleTo() {
+ assertIsConvertibleTo("tensor(x[])", "tensor(x[])");
+ assertUnconvertibleTo("tensor(x[])", "tensor(y[])");
+ assertIsConvertibleTo("tensor(x[10])", "tensor(x[])");
+ assertUnconvertibleTo("tensor(x[])", "tensor(x[10])");
+ assertUnconvertibleTo("tensor(x[10])", "tensor(x[5])");
+ assertIsConvertibleTo("tensor(x[5])", "tensor(x[10])"); // Different from assignable
+ assertUnconvertibleTo("tensor(x{})", "tensor(x[])");
+ assertIsConvertibleTo("tensor(x{},y[10])", "tensor(x{},y[])");
+ }
+
private static void assertTensorType(String typeSpec) {
assertTensorType(typeSpec, typeSpec);
}
@@ -100,4 +113,12 @@ public class TensorTypeTestCase {
assertFalse(TensorType.fromSpec(specificType).isAssignableTo(TensorType.fromSpec(generalType)));
}
+ private void assertIsConvertibleTo(String specificType, String generalType) {
+ assertTrue(TensorType.fromSpec(specificType).isConvertibleTo(TensorType.fromSpec(generalType)));
+ }
+
+ private void assertUnconvertibleTo(String specificType, String generalType) {
+ assertFalse(TensorType.fromSpec(specificType).isConvertibleTo(TensorType.fromSpec(generalType)));
+ }
+
}