diff options
Diffstat (limited to 'vespajlib')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java | 4 | ||||
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java | 26 |
2 files changed, 21 insertions, 9 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index d94f7f1529a..6c02cd08294 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -68,9 +68,7 @@ public class Concat extends PrimitiveTensorFunction { int aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(aIndexed::size).orElseThrow(RuntimeException::new); int[] aToIndexes = mapIndexes(a.type(), concatType); int[] bToIndexes = mapIndexes(b.type(), concatType); - System.out.println("Concatenating " + a + " to " + b); concatenateTo(aIndexed, bIndexed, aDimensionLength, concatType, aToIndexes, bToIndexes, builder); - System.out.println("Concatenating " + b + " to " + a); concatenateTo(bIndexed, aIndexed, 0, concatType, bToIndexes, aToIndexes, builder); return builder.build(); } @@ -83,14 +81,12 @@ public class Concat extends PrimitiveTensorFunction { TensorAddress aAddress = iaSubspace.address(); for (Iterator<IndexedTensor.SubspaceIterator> ib = b.subspaceIterator(otherADimensions); ib.hasNext();) { IndexedTensor.SubspaceIterator ibSubspace = ib.next(); - System.out.println(" Producing concatenation along '" + dimension + "' starting at b address " + ibSubspace.address()); while (ibSubspace.hasNext()) { java.util.Map.Entry<TensorAddress, Double> bCell = ibSubspace.next(); // TODO: Create Cell convenience subclass for Map.Entry TensorAddress combinedAddress = combineAddresses(aAddress, aToIndexes, bCell.getKey(), bToIndexes, concatType, offset, dimension); if (combinedAddress == null) continue; // incompatible - System.out.println(" Setting " + combinedAddress + " = " + bCell.getValue()); builder.cell(combinedAddress, bCell.getValue()); } iaSubspace.reset(); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java index 7136248ea0c..8a920ad2cec 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java @@ -10,7 +10,7 @@ import static org.junit.Assert.fail; * @author bratseth */ public class ConcatTestCase { - + @Test public void testConcatNumbers() { Tensor a = Tensor.from("{1}"); @@ -38,7 +38,7 @@ public class ConcatTestCase { } @Test - public void testUnequalEqualShapes() { + public void testUnequalEqualSizesSameDimension() { Tensor a = Tensor.from("tensor(x[]):{ {x:0}:1, {x:1}:2 }"); Tensor b = Tensor.from("tensor(x[]):{ {x:0}:4, {x:1}:5, {x:2}:6 }"); assertEquals(Tensor.from("tensor(x[5]):{ {x:0}:1, {x:1}:2, {x:2}:4, {x:3}:5, {x:4}:6 }"), a.concat(b, "x")); @@ -46,10 +46,26 @@ public class ConcatTestCase { assertEquals(Tensor.from("tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:2,y:0}:3, " + "{x:0,y:1}:4, {x:1,y:1}:5, {x:2,y:1}:6 }"), a.concat(b, "y")); fail("Expected exception"); - } - catch (IllegalArgumentException expected) { + } catch (IllegalArgumentException expected) { // success } } -} + @Test + public void testUnequalEqualSizesDifferentDimension() { + Tensor a = Tensor.from("tensor(x[]):{ {x:0}:1, {x:1}:2 }"); + Tensor b = Tensor.from("tensor(y[]):{ {y:0}:4, {y:1}:5, {y:2}:6 }"); + assertEquals(Tensor.from("tensor(x[3],y[3]):{{x:0,y:0}:1.0,{x:0,y:1}:1.0,{x:0,y:2}:1.0,{x:1,y:0}:2.0,{x:1,y:1}:2.0,{x:1,y:2}:2.0,{x:2,y:0}:4.0,{x:2,y:1}:5.0,{x:2,y:2}:6.0}"), a.concat(b, "x")); + assertEquals(Tensor.from("tensor(x[2],y[4]):{{x:0,y:0}:1.0,{x:0,y:1}:4.0,{x:0,y:2}:5.0,{x:0,y:3}:6.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:1,y:2}:5.0,{x:1,y:3}:6.0}"), a.concat(b, "y")); + assertEquals(Tensor.from("tensor(x[2],y[3],z[2]):{{x:0,y:0,z:0}:1.0,{x:0,y:0,z:1}:4.0,{x:0,y:1,z:0}:1.0,{x:0,y:1,z:1}:5.0,{x:0,y:2,z:0}:1.0,{x:0,y:2,z:1}:6.0,{x:1,y:0,z:0}:2.0,{x:1,y:0,z:1}:4.0,{x:1,y:1,z:0}:2.0,{x:1,y:1,z:1}:5.0,{x:1,y:2,z:0}:2.0,{x:1,y:2,z:1}:6.0}"), a.concat(b, "z")); + } + + @Test + public void testDimensionsubset() { + Tensor a = Tensor.from("tensor(x[],y[]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:0,y:1}:3, {x:1,y:1}:4 }"); + Tensor b = Tensor.from("tensor(y[2]):{ {y:0}:5, {y:1}:6 }"); + assertEquals(Tensor.from("tensor(x[3],y[2]):{{x:0,y:0}:1.0,{x:0,y:1}:3.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:2,y:0}:5.0,{x:2,y:1}:6.0}"), a.concat(b, "x")); + assertEquals(Tensor.from("tensor(x[2],y[4]):{{x:0,y:0}:1.0,{x:0,y:1}:3.0,{x:0,y:2}:5.0,{x:0,y:3}:6.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:1,y:2}:5.0,{x:1,y:3}:6.0}"), a.concat(b, "y")); + } + +}
\ No newline at end of file |