aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/test
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2024-01-19 14:32:23 +0100
committerHenning Baldersheim <balder@yahoo-inc.com>2024-01-19 14:32:23 +0100
commit3cf8be5fe60504a02be04009b9348913ae32b564 (patch)
tree2efea86685fb8c94725e71628ddaeaa683a1faed /vespajlib/src/test
parent74cbf975a435d54eb892de0142d6cceb2d1ebc93 (diff)
Add a class for assist efficient traversal of dimensions in an IndexedTensor.
Diffstat (limited to 'vespajlib/src/test')
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java33
1 files changed, 33 insertions, 0 deletions
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java
index 0a6c821e64e..d145ad7a316 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java
@@ -9,6 +9,7 @@ import java.util.Iterator;
import java.util.Map;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
@@ -96,6 +97,38 @@ public class IndexedTensorTestCase {
}
@Test
+ public void testDirectIndexedAddress() {
+ TensorType type = new TensorType.Builder().indexed("v", 3)
+ .indexed("w", wSize)
+ .indexed("x", xSize)
+ .indexed("y", ySize)
+ .indexed("z", zSize)
+ .build();
+ var directAddress = DirectIndexedAddress.of(DimensionSizes.of(type));
+ assertThrows(ArrayIndexOutOfBoundsException.class, () -> directAddress.getStride(5));
+ assertThrows(IndexOutOfBoundsException.class, () -> directAddress.setIndex(4, 7));
+ assertEquals(wSize*xSize*ySize*zSize, directAddress.getStride(0));
+ assertEquals(xSize*ySize*zSize, directAddress.getStride(1));
+ assertEquals(ySize*zSize, directAddress.getStride(2));
+ assertEquals(zSize, directAddress.getStride(3));
+ assertEquals(1, directAddress.getStride(4));
+ assertEquals(0, directAddress.getIndex());
+ directAddress.setIndex(0,1);
+ assertEquals(1 * directAddress.getStride(0), directAddress.getIndex());
+ directAddress.setIndex(1,1);
+ assertEquals(1 * directAddress.getStride(0) + 1 * directAddress.getStride(1), directAddress.getIndex());
+ directAddress.setIndex(2,2);
+ directAddress.setIndex(3,2);
+ directAddress.setIndex(4,2);
+ long expected = 1 * directAddress.getStride(0) +
+ 1 * directAddress.getStride(1) +
+ 2 * directAddress.getStride(2) +
+ 2 * directAddress.getStride(3) +
+ 2 * directAddress.getStride(4);
+ assertEquals(expected, directAddress.getIndex());
+ }
+
+ @Test
public void testUnboundBuilding() {
TensorType type = new TensorType.Builder().indexed("w")
.indexed("v")