diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-10-18 13:55:30 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-10-18 13:55:30 +0200 |
commit | 402ea31a2b4b1161efb3c2a9675f712a8c2b6718 (patch) | |
tree | 907b6e485ec26077a696b0aae15db57633d0b120 /vespajlib | |
parent | 3f28b1429b9b91070451c698f7201286763b86c1 (diff) |
Nonfunctional changes only
Diffstat (limited to 'vespajlib')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/Tensor.java | 4 | ||||
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java | 13 |
2 files changed, 15 insertions, 2 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 5ad1a8f1e17..5e3af70cba4 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -112,6 +112,10 @@ public interface Tensor { Collections.singletonList(toDimension)).evaluate(); } + default Tensor concat(double argument, String dimension) { + return concat(Tensor.Builder.of(TensorType.empty).cell(argument).build(), dimension); + } + default Tensor concat(Tensor argument, String dimension) { return new Concat(new ConstantTensor(this), new ConstantTensor(argument), dimension).evaluate(); } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java index c47bfe84373..30078b4a826 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -71,9 +71,18 @@ public class TensorTestCase { Tensor y = Tensor.from("{{y:1}:3}"); Tensor x = Tensor.from("{{x:0}:5,{x:1}:7}"); Tensor xy = Tensor.from("{{x:0,y:1}:11, {x:1,y:1}:13}"); - double nest = y.multiply(x.multiply(xy).sum("x")).sum("y").asDouble(); + double nest1 = y.multiply(x.multiply(xy).sum("x")).sum("y").asDouble(); + double nest2 = x.multiply(xy).sum("x").multiply(y).sum("y").asDouble(); double flat = y.multiply(x).multiply(xy).sum(ImmutableList.of("x","y")).asDouble(); - assertEquals(nest, flat, 0.000000001); + assertEquals(nest1, flat, 0.000000001); + assertEquals(nest2, flat, 0.000000001); + } + + @Test + public void testCombineInDimensionIndexed() { + Tensor input = Tensor.from("tensor(input[]):{{input:0}:3, {input:1}:7}"); + Tensor result = input.concat(11, "input"); + assertEquals("{{input:0}:3.0,{input:1}:7.0,{input:2}:11.0}", result.toString()); } /** All functions are more throughly tested in searchlib EvaluationTestCase */ |