summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2017-01-03 10:34:40 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2017-01-03 10:34:40 +0100
commit6dfd393b05800e53a0becf85ddd12614c22a77be (patch)
tree787a1388b60c8a7e322d00455ca1433f6f8a08d4 /vespajlib
parent9f05c59b51e83971cd3530c1f4eadbdf071cf0d5 (diff)
More tests and remove println
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java4
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java26
2 files changed, 21 insertions, 9 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index d94f7f1529a..6c02cd08294 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -68,9 +68,7 @@ public class Concat extends PrimitiveTensorFunction {
int aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(aIndexed::size).orElseThrow(RuntimeException::new);
int[] aToIndexes = mapIndexes(a.type(), concatType);
int[] bToIndexes = mapIndexes(b.type(), concatType);
- System.out.println("Concatenating " + a + " to " + b);
concatenateTo(aIndexed, bIndexed, aDimensionLength, concatType, aToIndexes, bToIndexes, builder);
- System.out.println("Concatenating " + b + " to " + a);
concatenateTo(bIndexed, aIndexed, 0, concatType, bToIndexes, aToIndexes, builder);
return builder.build();
}
@@ -83,14 +81,12 @@ public class Concat extends PrimitiveTensorFunction {
TensorAddress aAddress = iaSubspace.address();
for (Iterator<IndexedTensor.SubspaceIterator> ib = b.subspaceIterator(otherADimensions); ib.hasNext();) {
IndexedTensor.SubspaceIterator ibSubspace = ib.next();
- System.out.println(" Producing concatenation along '" + dimension + "' starting at b address " + ibSubspace.address());
while (ibSubspace.hasNext()) {
java.util.Map.Entry<TensorAddress, Double> bCell = ibSubspace.next(); // TODO: Create Cell convenience subclass for Map.Entry
TensorAddress combinedAddress = combineAddresses(aAddress, aToIndexes, bCell.getKey(), bToIndexes,
concatType, offset, dimension);
if (combinedAddress == null) continue; // incompatible
- System.out.println(" Setting " + combinedAddress + " = " + bCell.getValue());
builder.cell(combinedAddress, bCell.getValue());
}
iaSubspace.reset();
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java
index 7136248ea0c..8a920ad2cec 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java
@@ -10,7 +10,7 @@ import static org.junit.Assert.fail;
* @author bratseth
*/
public class ConcatTestCase {
-
+
@Test
public void testConcatNumbers() {
Tensor a = Tensor.from("{1}");
@@ -38,7 +38,7 @@ public class ConcatTestCase {
}
@Test
- public void testUnequalEqualShapes() {
+ public void testUnequalEqualSizesSameDimension() {
Tensor a = Tensor.from("tensor(x[]):{ {x:0}:1, {x:1}:2 }");
Tensor b = Tensor.from("tensor(x[]):{ {x:0}:4, {x:1}:5, {x:2}:6 }");
assertEquals(Tensor.from("tensor(x[5]):{ {x:0}:1, {x:1}:2, {x:2}:4, {x:3}:5, {x:4}:6 }"), a.concat(b, "x"));
@@ -46,10 +46,26 @@ public class ConcatTestCase {
assertEquals(Tensor.from("tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:2,y:0}:3, " +
"{x:0,y:1}:4, {x:1,y:1}:5, {x:2,y:1}:6 }"), a.concat(b, "y"));
fail("Expected exception");
- }
- catch (IllegalArgumentException expected) {
+ } catch (IllegalArgumentException expected) {
// success
}
}
-}
+ @Test
+ public void testUnequalEqualSizesDifferentDimension() {
+ Tensor a = Tensor.from("tensor(x[]):{ {x:0}:1, {x:1}:2 }");
+ Tensor b = Tensor.from("tensor(y[]):{ {y:0}:4, {y:1}:5, {y:2}:6 }");
+ assertEquals(Tensor.from("tensor(x[3],y[3]):{{x:0,y:0}:1.0,{x:0,y:1}:1.0,{x:0,y:2}:1.0,{x:1,y:0}:2.0,{x:1,y:1}:2.0,{x:1,y:2}:2.0,{x:2,y:0}:4.0,{x:2,y:1}:5.0,{x:2,y:2}:6.0}"), a.concat(b, "x"));
+ assertEquals(Tensor.from("tensor(x[2],y[4]):{{x:0,y:0}:1.0,{x:0,y:1}:4.0,{x:0,y:2}:5.0,{x:0,y:3}:6.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:1,y:2}:5.0,{x:1,y:3}:6.0}"), a.concat(b, "y"));
+ assertEquals(Tensor.from("tensor(x[2],y[3],z[2]):{{x:0,y:0,z:0}:1.0,{x:0,y:0,z:1}:4.0,{x:0,y:1,z:0}:1.0,{x:0,y:1,z:1}:5.0,{x:0,y:2,z:0}:1.0,{x:0,y:2,z:1}:6.0,{x:1,y:0,z:0}:2.0,{x:1,y:0,z:1}:4.0,{x:1,y:1,z:0}:2.0,{x:1,y:1,z:1}:5.0,{x:1,y:2,z:0}:2.0,{x:1,y:2,z:1}:6.0}"), a.concat(b, "z"));
+ }
+
+ @Test
+ public void testDimensionsubset() {
+ Tensor a = Tensor.from("tensor(x[],y[]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:0,y:1}:3, {x:1,y:1}:4 }");
+ Tensor b = Tensor.from("tensor(y[2]):{ {y:0}:5, {y:1}:6 }");
+ assertEquals(Tensor.from("tensor(x[3],y[2]):{{x:0,y:0}:1.0,{x:0,y:1}:3.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:2,y:0}:5.0,{x:2,y:1}:6.0}"), a.concat(b, "x"));
+ assertEquals(Tensor.from("tensor(x[2],y[4]):{{x:0,y:0}:1.0,{x:0,y:1}:3.0,{x:0,y:2}:5.0,{x:0,y:3}:6.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:1,y:2}:5.0,{x:1,y:3}:6.0}"), a.concat(b, "y"));
+ }
+
+} \ No newline at end of file