summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2017-01-02 16:03:43 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2017-01-02 16:03:43 +0100
commit9f05c59b51e83971cd3530c1f4eadbdf071cf0d5 (patch)
treecca2aac94ed9c9321baf8bd78792341e2b5be267 /vespajlib/src/main/java/com/yahoo/tensor
parentded9e870509772e87e7fe42d888d20246e3c7d03 (diff)
Validate sizes
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java6
1 files changed, 5 insertions, 1 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index a875b392de7..d94f7f1529a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -83,7 +83,7 @@ public class Concat extends PrimitiveTensorFunction {
TensorAddress aAddress = iaSubspace.address();
for (Iterator<IndexedTensor.SubspaceIterator> ib = b.subspaceIterator(otherADimensions); ib.hasNext();) {
IndexedTensor.SubspaceIterator ibSubspace = ib.next();
- System.out.println(" Producing concatenation along '" + dimension + " starting at b address" + ibSubspace.address());
+ System.out.println(" Producing concatenation along '" + dimension + "' starting at b address " + ibSubspace.address());
while (ibSubspace.hasNext()) {
java.util.Map.Entry<TensorAddress, Double> bCell = ibSubspace.next(); // TODO: Create Cell convenience subclass for Map.Entry
TensorAddress combinedAddress = combineAddresses(aAddress, aToIndexes, bCell.getKey(), bToIndexes,
@@ -135,6 +135,10 @@ public class Concat extends PrimitiveTensorFunction {
int bSize = b.type().indexOfDimension(currentDimension).map(b::size).orElse(0);
if (currentDimension.equals(concatDimension))
joinedSizes[i] = aSize + bSize;
+ else if (aSize != 0 && bSize != 0 && aSize!=bSize )
+ throw new IllegalArgumentException("Dimension " + currentDimension + " must be of the same size when " +
+ "concatenating " + a.type() + " and " + b.type() + " along dimension " +
+ concatDimension + ", but was " + aSize + " and " + bSize);
else
joinedSizes[i] = Math.max(aSize, bSize);
}