diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-11-28 11:27:23 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-11-28 11:27:23 +0100 |
commit | ea771fd3d188a370d44884646924fc39f21d55c5 (patch) | |
tree | fc8cea8997c8672ea435b932bc9b4a3ffd9ea30c /vespajlib | |
parent | 7b0601cdf8cfe76294f6f39a83b81d584ce7356d (diff) |
Tensor benchmark and convenience functions
Diffstat (limited to 'vespajlib')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/Tensor.java | 18 | ||||
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java | 64 |
2 files changed, 82 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 4b17f65ea21..08f6a1aa6be 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -13,6 +13,7 @@ import com.yahoo.tensor.functions.Rename;
 import com.yahoo.tensor.functions.Softmax;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashSet;
 import java.util.List;
@@ -55,6 +56,19 @@ public interface Tensor {
 
     /** Returns the value of a cell, or NaN if this cell does not exist/have no value */
     double get(TensorAddress address);
+
+    /**
+     * Returns the value of this as a double if it has no dimensions and one value
+     *
+     * @throws IllegalStateException if this does not have zero dimensions and one value
+     */
+    default double asDouble() {
+        if (dimensions().size() > 0)
+            throw new IllegalStateException("This tensor is not dimensionless. Dimensions: " + dimensions());
+        if (cells().size() != 1)
+            throw new IllegalStateException("This tensor does not have a single value, it has " + cells().size());
+        return cells().values().iterator().next();
+    }
 
     // ----------------- Primitive tensor functions
 
@@ -63,6 +77,10 @@ public interface Tensor {
     }
 
     /** Aggregates cells over a set of dimensions, or over all dimensions if no dimensions are specified */
+    default Tensor reduce(Reduce.Aggregator aggregator, String ... dimensions) {
+        return new Reduce(new ConstantTensor(this), aggregator, Arrays.asList(dimensions)).evaluate();
+    }
+    /** Aggregates cells over a set of dimensions, or over all dimensions if no dimensions are specified */
     default Tensor reduce(Reduce.Aggregator aggregator, List<String> dimensions) {
         return new Reduce(new ConstantTensor(this), aggregator, dimensions).evaluate();
     }
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
new file mode 100644
index 00000000000..5df54ca9ced
--- /dev/null
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
@@ -0,0 +1,64 @@
+package com.yahoo.tensor;
+
+import com.yahoo.tensor.functions.Reduce;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+
+/**
+ * Microbenchmark of tensor operations.
+ *
+ * @author bratseth
+ */
+public class TensorFunctionBenchmark {
+
+    private final Random random = new Random();
+
+    public void benchmark(int iterations) {
+        List<Tensor> modelVectors = generateVectors(100, 300);
+        Tensor queryVector = generateVectors(1, 300).get(0);
+        dotProduct(queryVector, modelVectors, 10); // warmup
+        long startTime = System.currentTimeMillis();
+        dotProduct(queryVector, modelVectors, iterations);
+        long totalTime = System.currentTimeMillis() - startTime;
+        System.out.println("Time per join: " + ((double) totalTime / iterations) + " ms"); // divide as double: long / int would truncate to whole ms
+    }
+
+    private double dotProduct(Tensor tensor, List<Tensor> tensors, int iterations) {
+        double result = 0;
+        for (int i = 0 ; i < iterations; i++)
+            result = dotProduct(tensor, tensors);
+        return result;
+    }
+
+    private double dotProduct(Tensor tensor, List<Tensor> tensors) {
+        double largest = Double.NEGATIVE_INFINITY; // not Double.MIN_VALUE: that is the smallest positive double
+        for (Tensor tensorElement : tensors) {
+            double dotProduct = tensor.join(tensorElement, (a, b) -> a * b).reduce(Reduce.Aggregator.sum).asDouble();
+            if (dotProduct > largest)
+                largest = dotProduct;
+        }
+        System.out.println(largest); // NOTE(review): runs inside the timed loop, so console I/O is part of the measurement
+        return largest;
+    }
+
+    private List<Tensor> generateVectors(int vectorCount, int vectorSize) {
+        List<Tensor> tensors = new ArrayList<>();
+        for (int i = 0; i < vectorCount; i++) {
+            MapTensorBuilder builder = new MapTensorBuilder();
+            builder.dimension("x");
+            for (int j = 0; j < vectorSize; j++) {
+                builder.cell().label("x", String.valueOf(j)).value(random.nextDouble());
+            }
+            tensors.add(builder.build());
+        }
+        return tensors;
+    }
+
+    public static void main(String[] args) {
+        // Was: 144 ms
+        new TensorFunctionBenchmark().benchmark(100);
+    }
+
+}