summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/test/java/com/yahoo/tensor/functions
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-12-06 08:57:09 -0800
committerJon Bratseth <bratseth@verizonmedia.com>2019-12-06 08:57:09 -0800
commit7ef64a61b4f04a400428fe58ed2475aa37c43d39 (patch)
tree590627375d361e3d879285abb4210e70b84a29b0 /vespajlib/src/test/java/com/yahoo/tensor/functions
parente4b328f4ee05b55131420df7f6b5a3685d5dffa5 (diff)
Generalized Slice tensor function
Diffstat (limited to 'vespajlib/src/test/java/com/yahoo/tensor/functions')
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/SliceTestCase.java138
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/ValueTestCase.java75
2 files changed, 138 insertions, 75 deletions
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/SliceTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/SliceTestCase.java
new file mode 100644
index 00000000000..55e6151f7e9
--- /dev/null
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/SliceTestCase.java
@@ -0,0 +1,138 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor.functions;
+
+import com.yahoo.tensor.Tensor;
+import org.junit.Test;
+
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+/**
+ * @author bratseth
+ */
+public class SliceTestCase {
+
+ private static final double delta = 0.000001;
+
+ @Test
+ public void testSliceFunctionGeneralFormToRank0() {
+ Tensor input = Tensor.from("tensor(key{},x{}):{ {key:foo,x:0}:1.4, {key:bar,x:0}:2.3 }");
+ Tensor result = new Slice<>(new ConstantTensor<>(input),
+ List.of(new Slice.DimensionValue<>("key", "bar"),
+ new Slice.DimensionValue<>("x", 0)))
+ .evaluate();
+ assertEquals(0, result.type().rank());
+ assertEquals(2.3, result.asDouble(), delta);
+ }
+
+ @Test
+ public void testSliceFunctionGeneralFormToRank0ReverseDimensionOrder() {
+ Tensor input = Tensor.from("tensor(key{},x{}):{ {key:foo,x:0}:1.4, {key:bar,x:0}:2.3 }");
+ Tensor result = new Slice<>(new ConstantTensor<>(input),
+ List.of(new Slice.DimensionValue<>("x", 0),
+ new Slice.DimensionValue<>("key", "bar")))
+ .evaluate();
+ assertEquals(0, result.type().rank());
+ assertEquals(2.3, result.asDouble(), delta);
+ }
+
+ @Test
+ public void testSliceFunctionGeneralFormToIndexedRank2to1() {
+ Tensor input = Tensor.from("tensor(key{},x[2]):{ {key:foo,x:0}:1.3, {key:foo,x:1}:1.4, {key:bar,x:0}:2.3, {key:bar,x:1}:2.4 }");
+ Tensor result = new Slice<>(new ConstantTensor<>(input),
+ List.of(new Slice.DimensionValue<>("key", "bar")))
+ .evaluate();
+ assertEquals(1, result.type().rank());
+ assertEquals(Tensor.from("tensor(x[2]):[2.3, 2.4]]"), result);
+ }
+
+ @Test
+ public void testSliceFunctionGeneralFormToMappedRank2to1() {
+ Tensor input = Tensor.from("tensor(key{},x[2]):{ {key:foo,x:0}:1.3, {key:foo,x:1}:1.4, {key:bar,x:0}:2.3, {key:bar,x:1}:2.4 }");
+ Tensor result = new Slice<>(new ConstantTensor<>(input),
+ List.of(new Slice.DimensionValue<>("x", 0)))
+ .evaluate();
+ assertEquals(1, result.type().rank());
+ assertEquals(Tensor.from("tensor(key{}):{{key:foo}:1.3, {key:bar}:2.3}"), result);
+ }
+
+ @Test
+ public void testSliceFunctionGeneralFormToMappedRank3to1() {
+ Tensor input = Tensor.from("tensor(key{},x[2],y[1]):{ {key:foo,x:0,y:0}:1.3, {key:foo,x:1,y:0}:1.4, {key:bar,x:0,y:0}:2.3, {key:bar,x:1,y:0}:2.4 }");
+ Tensor result = new Slice<>(new ConstantTensor<>(input),
+ List.of(new Slice.DimensionValue<>("x", 1),
+ new Slice.DimensionValue<>("y", 0)))
+ .evaluate();
+ assertEquals(1, result.type().rank());
+ assertEquals(Tensor.from("tensor(key{}):{{key:foo}:1.4, {key:bar}:2.4}"), result);
+ }
+
+ @Test
+ public void testSliceFunctionGeneralFormToMappedRank3to1ReverseDimensionOrder() {
+ Tensor input = Tensor.from("tensor(key{},x[2],y[1]):{ {key:foo,x:0,y:0}:1.3, {key:foo,x:1,y:0}:1.4, {key:bar,x:0,y:0}:2.3, {key:bar,x:1,y:0}:2.4 }");
+ Tensor result = new Slice<>(new ConstantTensor<>(input),
+ List.of(new Slice.DimensionValue<>("y", 0),
+ new Slice.DimensionValue<>("x", 1)))
+ .evaluate();
+ assertEquals(1, result.type().rank());
+ assertEquals(Tensor.from("tensor(key{}):{{key:foo}:1.4, {key:bar}:2.4}"), result);
+ }
+
+ @Test
+ public void testSliceFunctionGeneralFormToMappedRank3to2() {
+ Tensor input = Tensor.from("tensor(key{},x[2],y[1]):{ {key:foo,x:0,y:0}:1.3, {key:foo,x:1,y:0}:1.4, {key:bar,x:0,y:0}:2.3, {key:bar,x:1,y:0}:2.4 }");
+ Tensor result = new Slice<>(new ConstantTensor<>(input),
+ List.of(new Slice.DimensionValue<>("x", 1)))
+ .evaluate();
+ assertEquals(2, result.type().rank());
+ assertEquals(Tensor.from("tensor(key{},y[1]):{{key:foo,y:0}:1.4, {key:bar,y:0}:2.4}"), result);
+ }
+
+ @Test
+ public void testSliceFunctionSingleMappedDimensionToRank0() {
+ Tensor input = Tensor.from("tensor(key{}):{ {key:foo}:1.4, {key:bar}:2.3 }");
+ Tensor result = new Slice<>(new ConstantTensor<>(input),
+ List.of(new Slice.DimensionValue<>("foo")))
+ .evaluate();
+ assertEquals(0, result.type().rank());
+ assertEquals(1.4, result.asDouble(), delta);
+ }
+
+ @Test
+ public void testSliceFunctionSingleIndexedDimensionToRank0() {
+ Tensor input = Tensor.from("tensor(key[3]):[1.1, 2.2, 3.3]");
+ Tensor result = new Slice<>(new ConstantTensor<>(input),
+ List.of(new Slice.DimensionValue<>(2)))
+ .evaluate();
+ assertEquals(0, result.type().rank());
+ assertEquals(3.3, result.asDouble(), delta);
+ }
+
+ @Test
+ public void testSliceFunctionShortFormWithMultipleDimensionsIsNotAllowed() {
+ try {
+ Tensor input = Tensor.from("tensor(key{},x{}):{ {key:foo,x:0}:1.4, {key:bar,x:0}:2.3 }");
+ new Slice<>(new ConstantTensor<>(input),
+ List.of(new Slice.DimensionValue<>("bar"),
+ new Slice.DimensionValue<>(0)))
+ .evaluate();
+ fail("Expected exception");
+ }
+ catch (IllegalArgumentException e) {
+ assertEquals("Short form of subspace addresses is only supported with a single dimension: Specify dimension names explicitly instead",
+ e.getMessage());
+ }
+ }
+
+ @Test
+ public void testToString() {
+ Tensor input = Tensor.from("tensor(key[3]):[1.1, 2.2, 3.3]");
+ assertEquals("tensor(key[3]):[1.1, 2.2, 3.3][2]",
+ new Slice<>(new ConstantTensor<>(input),
+ List.of(new Slice.DimensionValue<>(2)))
+ .toString());
+ }
+
+}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/ValueTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/ValueTestCase.java
deleted file mode 100644
index 227fbffbaa8..00000000000
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/ValueTestCase.java
+++ /dev/null
@@ -1,75 +0,0 @@
-// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.tensor.functions;
-
-import com.yahoo.tensor.Tensor;
-import org.junit.Test;
-
-import java.util.List;
-
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.fail;
-
-/**
- * @author bratseth
- */
-public class ValueTestCase {
-
- private static final double delta = 0.000001;
-
- @Test
- public void testValueFunctionGeneralForm() {
- Tensor input = Tensor.from("tensor(key{},x{}):{ {key:foo,x:0}:1.4, {key:bar,x:0}:2.3 }");
- Tensor result = new Value<>(new ConstantTensor<>(input),
- List.of(new Value.DimensionValue<>("key", "bar"),
- new Value.DimensionValue<>("x", 0)))
- .evaluate();
- assertEquals(0, result.type().rank());
- assertEquals(2.3, result.asDouble(), delta);
- }
-
- @Test
- public void testValueFunctionSingleMappedDimension() {
- Tensor input = Tensor.from("tensor(key{}):{ {key:foo}:1.4, {key:bar}:2.3 }");
- Tensor result = new Value<>(new ConstantTensor<>(input),
- List.of(new Value.DimensionValue<>("foo")))
- .evaluate();
- assertEquals(0, result.type().rank());
- assertEquals(1.4, result.asDouble(), delta);
- }
-
- @Test
- public void testValueFunctionSingleIndexedDimension() {
- Tensor input = Tensor.from("tensor(key[3]):[1.1, 2.2, 3.3]");
- Tensor result = new Value<>(new ConstantTensor<>(input),
- List.of(new Value.DimensionValue<>(2)))
- .evaluate();
- assertEquals(0, result.type().rank());
- assertEquals(3.3, result.asDouble(), delta);
- }
-
- @Test
- public void testValueFunctionShortFormWithMultipleDimensionsIsNotAllowed() {
- try {
- Tensor input = Tensor.from("tensor(key{},x{}):{ {key:foo,x:0}:1.4, {key:bar,x:0}:2.3 }");
- new Value<>(new ConstantTensor<>(input),
- List.of(new Value.DimensionValue<>("bar"),
- new Value.DimensionValue<>(0)))
- .evaluate();
- fail("Expected exception");
- }
- catch (IllegalArgumentException e) {
- assertEquals("Short form of cell addresses is only supported with a single dimension: Specify dimension names explicitly",
- e.getMessage());
- }
- }
-
- @Test
- public void testToString() {
- Tensor input = Tensor.from("tensor(key[3]):[1.1, 2.2, 3.3]");
- assertEquals("tensor(key[3]):[1.1, 2.2, 3.3][2]",
- new Value<>(new ConstantTensor<>(input),
- List.of(new Value.DimensionValue<>(2)))
- .toString());
- }
-
-}