diff options
Diffstat (limited to 'vespajlib/src/test/java/com/yahoo/tensor/MixedTensorTestCase.java')
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/MixedTensorTestCase.java | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/MixedTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/MixedTensorTestCase.java index 3ed8a7237ec..4ab60ecb9b9 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/MixedTensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/MixedTensorTestCase.java @@ -2,9 +2,14 @@ package com.yahoo.tensor; +import com.yahoo.tensor.functions.Reduce; import org.junit.Test; +import java.util.HashMap; +import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -159,4 +164,25 @@ public class MixedTensorTestCase { tensor.toString()); } + @Test + public void testSplitIntoDense() { + TensorType type = new TensorType.Builder().mapped("key").indexed("x", 3).build(); + Tensor tensor = MixedTensor.Builder.of(type). + cell().label("key", "key1").label("x", 0).value(1). + cell().label("key", "key1").label("x", 1).value(2). + cell().label("key", "key1").label("x", 2).value(3). + cell().label("key", "key2").label("x", 0).value(4). + cell().label("key", "key2").label("x", 1).value(5). + cell().label("key", "key2").label("x", 2).value(6). + build(); + + Map<String, Tensor> indexedTensors = new HashMap<>(); + tensor.sum("x").cellIterator() + .forEachRemaining(cell -> indexedTensors.put(cell.getKey().label(0), + tensor.multiply(Tensor.Builder.of(type.mappedSubtype()).cell(cell.getKey(), 1.0).build()).sum("key"))); + + assertEquals("tensor(x[3]):[1.0, 2.0, 3.0]", indexedTensors.get("key1").toString()); + assertEquals("tensor(x[3]):[4.0, 5.0, 6.0]", indexedTensors.get("key2").toString()); + } + } |