diff options
Diffstat (limited to 'vespajlib/src/test')
5 files changed, 120 insertions, 20 deletions
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java index afc95d295f0..528ca57d256 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java @@ -46,12 +46,7 @@ public class IndexedTensorTestCase { @Test public void testNegativeLabels() { - TensorAddress numeric = TensorAddress.of(-1, 0, 1, 1234567, -1234567); - assertEquals("-1", numeric.label(0)); - assertEquals("0", numeric.label(1)); - assertEquals("1", numeric.label(2)); - assertEquals("1234567", numeric.label(3)); - assertEquals("-1234567", numeric.label(4)); + assertThrows(IndexOutOfBoundsException.class, () ->TensorAddress.of(-1, 0, 1, 1234567, -1234567)); } private void verifyFloat(String spec) { diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorAddressTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorAddressTestCase.java new file mode 100644 index 00000000000..472ebca2360 --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorAddressTestCase.java @@ -0,0 +1,72 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor; + +import static com.yahoo.tensor.TensorAddress.of; +import static com.yahoo.tensor.TensorAddress.ofLabels; + +import org.junit.jupiter.api.Test; + +import java.util.Arrays; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; + +/** + * Test for tensor address. + * + * @author baldersheim + */ +public class TensorAddressTestCase { + public static void equal(TensorAddress a, TensorAddress b) { + assertEquals(a.hashCode(), b.hashCode()); + assertEquals(a, b); + assertEquals(a.size(), b.size()); + for (int i = 0; i < a.size(); i++) { + assertEquals(a.label(i), b.label(i)); + assertEquals(a.numericLabel(i), b.numericLabel(i)); + } + } + public static void notEqual(TensorAddress a, TensorAddress b) { + assertNotEquals(a.hashCode(), b.hashCode()); // This might not hold, but is bad if not very rare + assertNotEquals(a, b); + } + @Test + void testStringVersusNumericAddressEquality() { + equal(ofLabels("1"), of(1)); + } + @Test + void testInEquality() { + notEqual(ofLabels("1"), ofLabels("2")); + notEqual(of(1), of(2)); + } + @Test + void testDimensionsEffectsEqualityAndHash() { + notEqual(ofLabels("1"), ofLabels("1", "1")); + notEqual(of(1), of(1, 1)); + } + @Test + void testAllowNullDimension() { + TensorAddress s1 = ofLabels("1", null, "2"); + TensorAddress s2 = ofLabels("1", "2"); + assertNotEquals(s1, s2); + assertEquals(-1, s1.numericLabel(1)); + assertEquals(null, s1.label(1)); + } + + private static void verifyWithLabel(int dimensions) { + int [] indexes = new int[dimensions]; + Arrays.fill(indexes, 1); + TensorAddress next = of(indexes); + for (int i = 0; i < dimensions; i++) { + indexes[i] = 3; + assertEquals(of(indexes), next = next.withLabel(i, 3)); + } + } + @Test + void testWithLabel() { + for (int i=0; i < 10; i++) { + verifyWithLabel(i); + } + } + +} diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java index 74237a218fb..91880c9af93 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java @@ -73,7 +73,7 @@ public class TensorFunctionBenchmark { for (int i = 0; i < vectorCount; i++) { Tensor.Builder builder = Tensor.Builder.of(type); for (int j = 0; j < vectorSize; j++) { - builder.cell().label("x", String.valueOf(j)).value(random.nextDouble()); + builder.cell().label("x", j).value(random.nextDouble()); } tensors.add(builder.build()); } @@ -88,8 +88,8 @@ public class TensorFunctionBenchmark { for (int i = 0; i < vectorCount; i++) { for (int j = 0; j < vectorSize; j++) { builder.cell() - .label("i", String.valueOf(i)) - .label("x", String.valueOf(j)) + .label("i", i) + .label("x", j) .value(random.nextDouble()); } } @@ -110,6 +110,7 @@ public class TensorFunctionBenchmark { double time = 0; // ---------------- Indexed unbound: + time = new TensorFunctionBenchmark().benchmark(50000, vectors(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, false); System.out.printf("Indexed unbound vectors, time per join: %1$8.3f ms\n", time); time = new TensorFunctionBenchmark().benchmark(50000, matrix(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, false); @@ -132,6 +133,7 @@ public class TensorFunctionBenchmark { // ---------------- Indexed (unbound) with extra space (sidesteps current special-case optimizations): time = new TensorFunctionBenchmark().benchmark(500, vectors(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, true); System.out.printf("Indexed vectors, x space time per join: %1$8.3f ms\n", time); + time = new TensorFunctionBenchmark().benchmark(500, matrix(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, true); System.out.printf("Indexed matrix, x space time per join: %1$8.3f ms\n", time); @@ -143,16 +145,16 @@ public class TensorFunctionBenchmark { System.out.printf("Mapped matrix, x space time per join: %1$8.3f ms\n", time); /* 2.4Ghz Intel Core i9, Macbook Pro 2019 - * Indexed unbound vectors, time per join: 0,067 ms - * Indexed unbound matrix, time per join: 0,107 ms - * Indexed bound vectors, time per join: 0,068 ms - * Indexed bound matrix, time per join: 0,105 ms - * Mapped vectors, time per join: 1,342 ms - * Mapped matrix, time per join: 3,448 ms - * Indexed vectors, x space time per join: 6,398 ms - * Indexed matrix, x space time per join: 3,220 ms - * Mapped vectors, x space time per join: 14,984 ms - * Mapped matrix, x space time per join: 19,873 ms + Indexed unbound vectors, time per join: 0,066 ms + Indexed unbound matrix, time per join: 0,108 ms + Indexed bound vectors, time per join: 0,068 ms + Indexed bound matrix, time per join: 0,106 ms + Mapped vectors, time per join: 0,845 ms + Mapped matrix, time per join: 1,779 ms + Indexed vectors, x space time per join: 5,778 ms + Indexed matrix, x space time per join: 3,342 ms + Mapped vectors, x space time per join: 8,184 ms + Mapped matrix, x space time per join: 11,547 ms */ } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java index 7cf0bd35b38..85619dca16c 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java @@ -33,7 +33,7 @@ public class DynamicTensorTestCase { public void testDynamicMappedRank1TensorFunction() { TensorType sparse = TensorType.fromSpec("tensor(x{})"); DynamicTensor<Name> t2 = DynamicTensor.from(sparse, - Collections.singletonMap(new TensorAddress.Builder(sparse).add("x", "a").build(), + java.util.Map.of(new TensorAddress.Builder(sparse).add("x", "a").build(), new Constant(5))); assertEquals(Tensor.from(sparse, "{{x:a}:5}"), t2.evaluate()); assertEquals("tensor(x{}):{{x:a}:5.0}", t2.toString()); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/impl/TensorAddressAnyTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/impl/TensorAddressAnyTestCase.java new file mode 100644 index 00000000000..ae13b95052b --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/tensor/impl/TensorAddressAnyTestCase.java @@ -0,0 +1,31 @@ +package com.yahoo.tensor.impl; + +import static com.yahoo.tensor.impl.TensorAddressAny.of; +import static com.yahoo.tensor.TensorAddressTestCase.equal; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +public class TensorAddressAnyTestCase { + @Test + void testSize() { + for (int i = 0; i < 10; i++) { + int [] indexes = new int [i]; + assertEquals(i, of(indexes).size()); + } + } + + @Test + void testNumericStringEquality() { + for (int i = 0; i < 10; i++) { + int [] numericIndexes = new int [i]; + String [] stringIndexes = new String[i]; + for (int j = 0; j < i; j++) { + numericIndexes[j] = j; + stringIndexes[j] = String.valueOf(j); + } + equal(of(stringIndexes), of(numericIndexes)); + } + } + +} |