summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2017-10-18 13:55:30 +0200
committerJon Bratseth <bratseth@yahoo-inc.com>2017-10-18 13:55:30 +0200
commit402ea31a2b4b1161efb3c2a9675f712a8c2b6718 (patch)
tree907b6e485ec26077a696b0aae15db57633d0b120 /vespajlib
parent3f28b1429b9b91070451c698f7201286763b86c1 (diff)
Nonfunctional changes only
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java4
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java13
2 files changed, 15 insertions, 2 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 5ad1a8f1e17..5e3af70cba4 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -112,6 +112,10 @@ public interface Tensor {
Collections.singletonList(toDimension)).evaluate();
}
+ default Tensor concat(double argument, String dimension) {
+ return concat(Tensor.Builder.of(TensorType.empty).cell(argument).build(), dimension);
+ }
+
default Tensor concat(Tensor argument, String dimension) {
return new Concat(new ConstantTensor(this), new ConstantTensor(argument), dimension).evaluate();
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
index c47bfe84373..30078b4a826 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
@@ -71,9 +71,18 @@ public class TensorTestCase {
Tensor y = Tensor.from("{{y:1}:3}");
Tensor x = Tensor.from("{{x:0}:5,{x:1}:7}");
Tensor xy = Tensor.from("{{x:0,y:1}:11, {x:1,y:1}:13}");
- double nest = y.multiply(x.multiply(xy).sum("x")).sum("y").asDouble();
+ double nest1 = y.multiply(x.multiply(xy).sum("x")).sum("y").asDouble();
+ double nest2 = x.multiply(xy).sum("x").multiply(y).sum("y").asDouble();
double flat = y.multiply(x).multiply(xy).sum(ImmutableList.of("x","y")).asDouble();
- assertEquals(nest, flat, 0.000000001);
+ assertEquals(nest1, flat, 0.000000001);
+ assertEquals(nest2, flat, 0.000000001);
+ }
+
+ @Test
+ public void testCombineInDimensionIndexed() {
+ Tensor input = Tensor.from("tensor(input[]):{{input:0}:3, {input:1}:7}");
+ Tensor result = input.concat(11, "input");
+ assertEquals("{{input:0}:3.0,{input:1}:7.0,{input:2}:11.0}", result.toString());
}
/** All functions are more throughly tested in searchlib EvaluationTestCase */