diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-01-02 16:03:43 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-01-02 16:03:43 +0100 |
commit | 9f05c59b51e83971cd3530c1f4eadbdf071cf0d5 (patch) | |
tree | cca2aac94ed9c9321baf8bd78792341e2b5be267 | |
parent | ded9e870509772e87e7fe42d888d20246e3c7d03 (diff) |
Validate sizes
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java | 6 | ||||
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java | 54 |
2 files changed, 40 insertions, 20 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index a875b392de7..d94f7f1529a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -83,7 +83,7 @@ public class Concat extends PrimitiveTensorFunction { TensorAddress aAddress = iaSubspace.address(); for (Iterator<IndexedTensor.SubspaceIterator> ib = b.subspaceIterator(otherADimensions); ib.hasNext();) { IndexedTensor.SubspaceIterator ibSubspace = ib.next(); - System.out.println(" Producing concatenation along '" + dimension + " starting at b address" + ibSubspace.address()); + System.out.println(" Producing concatenation along '" + dimension + "' starting at b address " + ibSubspace.address()); while (ibSubspace.hasNext()) { java.util.Map.Entry<TensorAddress, Double> bCell = ibSubspace.next(); // TODO: Create Cell convenience subclass for Map.Entry TensorAddress combinedAddress = combineAddresses(aAddress, aToIndexes, bCell.getKey(), bToIndexes, @@ -135,6 +135,10 @@ public class Concat extends PrimitiveTensorFunction { int bSize = b.type().indexOfDimension(currentDimension).map(b::size).orElse(0); if (currentDimension.equals(concatDimension)) joinedSizes[i] = aSize + bSize; + else if (aSize != 0 && bSize != 0 && aSize!=bSize ) + throw new IllegalArgumentException("Dimension " + currentDimension + " must be of the same size when " + + "concatenating " + a.type() + " and " + b.type() + " along dimension " + + concatDimension + ", but was " + aSize + " and " + bSize); else joinedSizes[i] = Math.max(aSize, bSize); } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java index 69f2c710d7a..7136248ea0c 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java @@ -4,6 +4,7 @@ import com.yahoo.tensor.Tensor; import org.junit.Test; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; /** * @author bratseth @@ -11,29 +12,44 @@ import static org.junit.Assert.assertEquals; public class ConcatTestCase { @Test - public void testConcat() { - { - Tensor a = Tensor.from("{1}"); - Tensor b = Tensor.from("{2}"); - assertEquals(Tensor.from("tensor(x[2]):{ {x:0}:1, {x:1}:2 }"), a.concat(b, "x")); - assertEquals(Tensor.from("tensor(x[2]):{ {x:0}:2, {x:1}:1 }"), b.concat(a, "x")); - } + public void testConcatNumbers() { + Tensor a = Tensor.from("{1}"); + Tensor b = Tensor.from("{2}"); + assertEquals(Tensor.from("tensor(x[2]):{ {x:0}:1, {x:1}:2 }"), a.concat(b, "x")); + assertEquals(Tensor.from("tensor(x[2]):{ {x:0}:2, {x:1}:1 }"), b.concat(a, "x")); + } + + @Test + public void testConcatEqualShapes() { + Tensor a = Tensor.from("tensor(x[]):{ {x:0}:1, {x:1}:2, {x:2}:3 }"); + Tensor b = Tensor.from("tensor(x[]):{ {x:0}:4, {x:1}:5, {x:2}:6 }"); + assertEquals(Tensor.from("tensor(x[6]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4, {x:4}:5, {x:5}:6 }"), a.concat(b, "x")); + assertEquals(Tensor.from("tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:2,y:0}:3, " + + "{x:0,y:1}:4, {x:1,y:1}:5, {x:2,y:1}:6 }"), a.concat(b, "y")); + } + + @Test + public void testConcatNumberAndVector() { + Tensor a = Tensor.from("{1}"); + Tensor b = Tensor.from("tensor(x[]):{ {x:0}:2, {x:1}:3, {x:2}:4 }"); + assertEquals(Tensor.from("tensor(x[4]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4 }"), a.concat(b, "x")); + assertEquals(Tensor.from("tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:1, {x:2,y:0}:1, " + + "{x:0,y:1}:2, {x:1,y:1}:3, {x:2,y:1}:4 }"), a.concat(b, "y")); + } - { - Tensor a = Tensor.from("tensor(x[]):{ {x:0}:1, {x:1}:2, {x:2}:3 }"); - Tensor b = Tensor.from("tensor(x[]):{ {x:0}:4, {x:1}:5, {x:2}:6 }"); - assertEquals(Tensor.from("tensor(x[6]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4, {x:4}:5, {x:5}:6 }"), a.concat(b, "x")); + @Test + public void testUnequalEqualShapes() { + Tensor a = Tensor.from("tensor(x[]):{ {x:0}:1, {x:1}:2 }"); + Tensor b = Tensor.from("tensor(x[]):{ {x:0}:4, {x:1}:5, {x:2}:6 }"); + assertEquals(Tensor.from("tensor(x[5]):{ {x:0}:1, {x:1}:2, {x:2}:4, {x:3}:5, {x:4}:6 }"), a.concat(b, "x")); + try { assertEquals(Tensor.from("tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:2,y:0}:3, " + "{x:0,y:1}:4, {x:1,y:1}:5, {x:2,y:1}:6 }"), a.concat(b, "y")); + fail("Expected exception"); } - - { - Tensor a = Tensor.from("{1}"); - Tensor b = Tensor.from("tensor(x[]):{ {x:0}:2, {x:1}:3, {x:2}:4 }"); - assertEquals(Tensor.from("tensor(x[4]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4 }"), a.concat(b, "x")); - assertEquals(Tensor.from("tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:1, {x:2,y:0}:1, " + - "{x:0,y:1}:2, {x:1,y:1}:3, {x:2,y:1}:4 }"), a.concat(b, "y")); + catch (IllegalArgumentException expected) { + // success } } - + } |