summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2017-01-02 16:03:43 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2017-01-02 16:03:43 +0100
commit9f05c59b51e83971cd3530c1f4eadbdf071cf0d5 (patch)
treecca2aac94ed9c9321baf8bd78792341e2b5be267
parentded9e870509772e87e7fe42d888d20246e3c7d03 (diff)
Validate sizes
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java6
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java54
2 files changed, 40 insertions, 20 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index a875b392de7..d94f7f1529a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -83,7 +83,7 @@ public class Concat extends PrimitiveTensorFunction {
TensorAddress aAddress = iaSubspace.address();
for (Iterator<IndexedTensor.SubspaceIterator> ib = b.subspaceIterator(otherADimensions); ib.hasNext();) {
IndexedTensor.SubspaceIterator ibSubspace = ib.next();
- System.out.println(" Producing concatenation along '" + dimension + " starting at b address" + ibSubspace.address());
+ System.out.println(" Producing concatenation along '" + dimension + "' starting at b address " + ibSubspace.address());
while (ibSubspace.hasNext()) {
java.util.Map.Entry<TensorAddress, Double> bCell = ibSubspace.next(); // TODO: Create Cell convenience subclass for Map.Entry
TensorAddress combinedAddress = combineAddresses(aAddress, aToIndexes, bCell.getKey(), bToIndexes,
@@ -135,6 +135,10 @@ public class Concat extends PrimitiveTensorFunction {
int bSize = b.type().indexOfDimension(currentDimension).map(b::size).orElse(0);
if (currentDimension.equals(concatDimension))
joinedSizes[i] = aSize + bSize;
+ else if (aSize != 0 && bSize != 0 && aSize!=bSize )
+ throw new IllegalArgumentException("Dimension " + currentDimension + " must be of the same size when " +
+ "concatenating " + a.type() + " and " + b.type() + " along dimension " +
+ concatDimension + ", but was " + aSize + " and " + bSize);
else
joinedSizes[i] = Math.max(aSize, bSize);
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java
index 69f2c710d7a..7136248ea0c 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java
@@ -4,6 +4,7 @@ import com.yahoo.tensor.Tensor;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
/**
* @author bratseth
@@ -11,29 +12,44 @@ import static org.junit.Assert.assertEquals;
public class ConcatTestCase {
@Test
- public void testConcat() {
- {
- Tensor a = Tensor.from("{1}");
- Tensor b = Tensor.from("{2}");
- assertEquals(Tensor.from("tensor(x[2]):{ {x:0}:1, {x:1}:2 }"), a.concat(b, "x"));
- assertEquals(Tensor.from("tensor(x[2]):{ {x:0}:2, {x:1}:1 }"), b.concat(a, "x"));
- }
+ public void testConcatNumbers() {
+ Tensor a = Tensor.from("{1}");
+ Tensor b = Tensor.from("{2}");
+ assertEquals(Tensor.from("tensor(x[2]):{ {x:0}:1, {x:1}:2 }"), a.concat(b, "x"));
+ assertEquals(Tensor.from("tensor(x[2]):{ {x:0}:2, {x:1}:1 }"), b.concat(a, "x"));
+ }
+
+ @Test
+ public void testConcatEqualShapes() {
+ Tensor a = Tensor.from("tensor(x[]):{ {x:0}:1, {x:1}:2, {x:2}:3 }");
+ Tensor b = Tensor.from("tensor(x[]):{ {x:0}:4, {x:1}:5, {x:2}:6 }");
+ assertEquals(Tensor.from("tensor(x[6]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4, {x:4}:5, {x:5}:6 }"), a.concat(b, "x"));
+ assertEquals(Tensor.from("tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:2,y:0}:3, " +
+ "{x:0,y:1}:4, {x:1,y:1}:5, {x:2,y:1}:6 }"), a.concat(b, "y"));
+ }
+
+ @Test
+ public void testConcatNumberAndVector() {
+ Tensor a = Tensor.from("{1}");
+ Tensor b = Tensor.from("tensor(x[]):{ {x:0}:2, {x:1}:3, {x:2}:4 }");
+ assertEquals(Tensor.from("tensor(x[4]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4 }"), a.concat(b, "x"));
+ assertEquals(Tensor.from("tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:1, {x:2,y:0}:1, " +
+ "{x:0,y:1}:2, {x:1,y:1}:3, {x:2,y:1}:4 }"), a.concat(b, "y"));
+ }
- {
- Tensor a = Tensor.from("tensor(x[]):{ {x:0}:1, {x:1}:2, {x:2}:3 }");
- Tensor b = Tensor.from("tensor(x[]):{ {x:0}:4, {x:1}:5, {x:2}:6 }");
- assertEquals(Tensor.from("tensor(x[6]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4, {x:4}:5, {x:5}:6 }"), a.concat(b, "x"));
+ @Test
+ public void testUnequalEqualShapes() {
+ Tensor a = Tensor.from("tensor(x[]):{ {x:0}:1, {x:1}:2 }");
+ Tensor b = Tensor.from("tensor(x[]):{ {x:0}:4, {x:1}:5, {x:2}:6 }");
+ assertEquals(Tensor.from("tensor(x[5]):{ {x:0}:1, {x:1}:2, {x:2}:4, {x:3}:5, {x:4}:6 }"), a.concat(b, "x"));
+ try {
assertEquals(Tensor.from("tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:2,y:0}:3, " +
"{x:0,y:1}:4, {x:1,y:1}:5, {x:2,y:1}:6 }"), a.concat(b, "y"));
+ fail("Expected exception");
}
-
- {
- Tensor a = Tensor.from("{1}");
- Tensor b = Tensor.from("tensor(x[]):{ {x:0}:2, {x:1}:3, {x:2}:4 }");
- assertEquals(Tensor.from("tensor(x[4]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4 }"), a.concat(b, "x"));
- assertEquals(Tensor.from("tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:1, {x:2,y:0}:1, " +
- "{x:0,y:1}:2, {x:1,y:1}:3, {x:2,y:1}:4 }"), a.concat(b, "y"));
+ catch (IllegalArgumentException expected) {
+ // success
}
}
-
+
}