diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-02-01 11:44:31 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-02-01 11:44:31 +0100 |
commit | 99ef288b5023f6879a944eaf2ba325de8997aa50 (patch) | |
tree | dc46760052327a41d5a585008aa2a67df670a75b /vespajlib | |
parent | c9044baf967cb8aac50ba63519b9f5b9097d9d8e (diff) |
Allow compatible changes to stored tensors
Allow increasing the size of tensor dimensions without making stored
data incompatible.
Diffstat (limited to 'vespajlib')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/TensorType.java | 22 | ||||
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java | 21 |
2 files changed, 42 insertions, 1 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 7eae9539e77..8ff9774fc7d 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -82,6 +82,21 @@ public class TensorType { * i.e if the given type is a generalization of this type. */ public boolean isAssignableTo(TensorType generalization) { + return isConvertibleOrAssignableTo(generalization, false); + } + + /** + * Returns whether this type can be converted to the given type. + * This is true if this type isAssignableTo the given type or + * if it is not assignable only because it has a shorter dimension length + * than the given type in some shared dimension(s), as it can then be + * converted to the given type by zero padding. + */ + public boolean isConvertibleTo(TensorType generalization) { + return isConvertibleOrAssignableTo(generalization, true); + } + + private boolean isConvertibleOrAssignableTo(TensorType generalization, boolean convertible) { if (generalization.dimensions().size() != this.dimensions().size()) return false; for (int i = 0; i < generalization.dimensions().size(); i++) { Dimension thisDimension = this.dimensions().get(i); @@ -90,7 +105,12 @@ public class TensorType { if ( ! thisDimension.name().equals(generalizationDimension.name())) return false; if (generalizationDimension.size().isPresent()) { if ( ! thisDimension.size().isPresent()) return false; - if ( ! thisDimension.size().get().equals(generalizationDimension.size().get()) ) return false; + if (convertible) { + if (thisDimension.size().get() > generalizationDimension.size().get()) return false; + } + else { // assignable + if (!thisDimension.size().get().equals(generalizationDimension.size().get())) return false; + } } } return true; diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java index c3e31fad2da..eef0b090fd1 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java @@ -12,6 +12,7 @@ import static org.junit.Assert.fail; /** * @author geirst + * @author bratseth */ public class TensorTypeTestCase { @@ -75,6 +76,18 @@ public class TensorTypeTestCase { assertIsAssignableTo("tensor(x{},y[10])", "tensor(x{},y[])"); } + @Test + public void testConvertibleTo() { + assertIsConvertibleTo("tensor(x[])", "tensor(x[])"); + assertUnconvertibleTo("tensor(x[])", "tensor(y[])"); + assertIsConvertibleTo("tensor(x[10])", "tensor(x[])"); + assertUnconvertibleTo("tensor(x[])", "tensor(x[10])"); + assertUnconvertibleTo("tensor(x[10])", "tensor(x[5])"); + assertIsConvertibleTo("tensor(x[5])", "tensor(x[10])"); // Different from assignable + assertUnconvertibleTo("tensor(x{})", "tensor(x[])"); + assertIsConvertibleTo("tensor(x{},y[10])", "tensor(x{},y[])"); + } + private static void assertTensorType(String typeSpec) { assertTensorType(typeSpec, typeSpec); } @@ -100,4 +113,12 @@ public class TensorTypeTestCase { assertFalse(TensorType.fromSpec(specificType).isAssignableTo(TensorType.fromSpec(generalType))); } + private void assertIsConvertibleTo(String specificType, String generalType) { + assertTrue(TensorType.fromSpec(specificType).isConvertibleTo(TensorType.fromSpec(generalType))); + } + + private void assertUnconvertibleTo(String specificType, String generalType) { + assertFalse(TensorType.fromSpec(specificType).isConvertibleTo(TensorType.fromSpec(generalType))); + } + } |