diff options
Diffstat (limited to 'vespajlib/src')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/Tensor.java | 6 | ||||
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java | 10 |
2 files changed, 16 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 2cba2e05536..5ad1a8f1e17 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -174,11 +174,17 @@ public interface Tensor { default Tensor equal(Tensor argument) { return join(argument, (a, b) -> ( a == b ? 1.0 : 0.0)); } default Tensor notEqual(Tensor argument) { return join(argument, (a, b) -> ( a != b ? 1.0 : 0.0)); } + default Tensor avg(String dimension) { return avg(Collections.singletonList(dimension)); } default Tensor avg(List<String> dimensions) { return reduce(Reduce.Aggregator.avg, dimensions); } + default Tensor count(String dimension) { return count(Collections.singletonList(dimension)); } default Tensor count(List<String> dimensions) { return reduce(Reduce.Aggregator.count, dimensions); } + default Tensor max(String dimension) { return max(Collections.singletonList(dimension)); } default Tensor max(List<String> dimensions) { return reduce(Reduce.Aggregator.max, dimensions); } + default Tensor min(String dimension) { return min(Collections.singletonList(dimension)); } default Tensor min(List<String> dimensions) { return reduce(Reduce.Aggregator.min, dimensions); } + default Tensor prod(String dimension) { return prod(Collections.singletonList(dimension)); } default Tensor prod(List<String> dimensions) { return reduce(Reduce.Aggregator.prod, dimensions); } + default Tensor sum(String dimension) { return sum(Collections.singletonList(dimension)); } default Tensor sum(List<String> dimensions) { return reduce(Reduce.Aggregator.sum, dimensions); } // ----------------- serialization diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java index f0b800bea7f..c47bfe84373 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -66,6 +66,16 @@ public class TensorTestCase { assertTrue(dimensions3.contains("d3")); } + @Test + public void testExpressions() { + Tensor y = Tensor.from("{{y:1}:3}"); + Tensor x = Tensor.from("{{x:0}:5,{x:1}:7}"); + Tensor xy = Tensor.from("{{x:0,y:1}:11, {x:1,y:1}:13}"); + double nest = y.multiply(x.multiply(xy).sum("x")).sum("y").asDouble(); + double flat = y.multiply(x).multiply(xy).sum(ImmutableList.of("x","y")).asDouble(); + assertEquals(nest, flat, 0.000000001); + } + /** All functions are more throughly tested in searchlib EvaluationTestCase */ @Test public void testTensorComputation() { |