summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2016-11-28 11:27:23 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2016-11-28 11:27:23 +0100
commitea771fd3d188a370d44884646924fc39f21d55c5 (patch)
treefc8cea8997c8672ea435b932bc9b4a3ffd9ea30c /vespajlib
parent7b0601cdf8cfe76294f6f39a83b81d584ce7356d (diff)
Tensor benchmark and convenience functions
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java18
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java64
2 files changed, 82 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 4b17f65ea21..08f6a1aa6be 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -13,6 +13,7 @@ import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.Softmax;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
@@ -55,6 +56,19 @@ public interface Tensor {
/** Returns the value of a cell, or NaN if this cell does not exist or has no value */
double get(TensorAddress address);
+
+    /**
+     * Returns the value of this as a double if it has no dimensions and one value
+     *
+     * @return the value of the single cell of this tensor
+     * @throws IllegalStateException if this does not have zero dimensions and one value
+     */
+    default double asDouble() {
+        // Only a dimensionless tensor can be unambiguously read as a scalar
+        if (dimensions().size() > 0)
+            throw new IllegalStateException("This tensor is not dimensionless. Dimensions: " + dimensions());
+        if (cells().size() != 1)
+            throw new IllegalStateException("This tensor does not have a single value, it has " + cells().size());
+        // Exactly one cell is present at this point, so the first value is the only value
+        return cells().values().iterator().next();
+    }
// ----------------- Primitive tensor functions
@@ -63,6 +77,10 @@ public interface Tensor {
}
/** Aggregates cells over a set of dimensions, or over all dimensions if no dimensions are specified */
+    default Tensor reduce(Reduce.Aggregator aggregator, String ... dimensions) {
+        // View the varargs array as a list, which is the form the Reduce function takes
+        List<String> dimensionList = Arrays.asList(dimensions);
+        // This tensor is wrapped as the constant argument of the reduce function
+        Reduce reduction = new Reduce(new ConstantTensor(this), aggregator, dimensionList);
+        return reduction.evaluate();
+    }
+ /** Aggregates cells over a set of dimensions, or over all dimensions if no dimensions are specified */
default Tensor reduce(Reduce.Aggregator aggregator, List<String> dimensions) {
    // This tensor is wrapped as the constant argument of the reduce function, which is then evaluated
    Reduce reduction = new Reduce(new ConstantTensor(this), aggregator, dimensions);
    return reduction.evaluate();
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
new file mode 100644
index 00000000000..5df54ca9ced
--- /dev/null
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
@@ -0,0 +1,64 @@
+package com.yahoo.tensor;
+
+import com.yahoo.tensor.functions.Reduce;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+
+/**
+ * Microbenchmark of tensor operations.
+ * Times the dot product of one query vector against a list of model vectors,
+ * where each dot product is computed as a join (multiplication) followed by a sum reduce.
+ *
+ * @author bratseth
+ */
+public class TensorFunctionBenchmark {
+
+    private final Random random = new Random();
+
+    /**
+     * Runs the benchmark and prints the average time per iteration to standard out,
+     * followed by the largest dot product found in the last iteration.
+     * One iteration computes the dot product of the query vector with every model vector.
+     *
+     * @param iterations the number of timed iterations, must be positive
+     * @throws IllegalArgumentException if iterations is not positive
+     */
+    public void benchmark(int iterations) {
+        if (iterations <= 0)
+            throw new IllegalArgumentException("iterations must be positive, got " + iterations);
+        List<Tensor> modelVectors = generateVectors(100, 300);
+        Tensor queryVector = generateVectors(1, 300).get(0);
+        dotProduct(queryVector, modelVectors, 10); // warmup
+        // nanoTime is meant for measuring elapsed time and is unaffected by wall-clock adjustments
+        long startTime = System.nanoTime();
+        double largest = dotProduct(queryVector, modelVectors, iterations);
+        long totalTimeNanos = System.nanoTime() - startTime;
+        // Floating point division: integer division would truncate sub-millisecond timings to 0
+        System.out.println("Time per join: " + (totalTimeNanos / 1_000_000.0 / iterations) + " ms");
+        // Printed after timing: this consumes the result without doing console IO in the timed region
+        System.out.println(largest);
+    }
+
+    /**
+     * Repeats the largest dot product computation the given number of times.
+     *
+     * @return the result of the last repetition, or 0 if iterations is not positive
+     */
+    private double dotProduct(Tensor tensor, List<Tensor> tensors, int iterations) {
+        double result = 0;
+        for (int i = 0; i < iterations; i++)
+            result = dotProduct(tensor, tensors);
+        return result;
+    }
+
+    /**
+     * Returns the largest dot product between the given tensor and any of the tensors in the list,
+     * or negative infinity if the list is empty.
+     */
+    private double dotProduct(Tensor tensor, List<Tensor> tensors) {
+        // Not Double.MIN_VALUE: that is the smallest positive double, so zero and negative
+        // dot products would never replace it
+        double largest = Double.NEGATIVE_INFINITY;
+        for (Tensor tensorElement : tensors) {
+            // Join multiplies the cell values, the sum reduce without dimensions collapses to one value
+            double dotProduct = tensor.join(tensorElement, (a, b) -> a * b).reduce(Reduce.Aggregator.sum).asDouble();
+            if (dotProduct > largest)
+                largest = dotProduct;
+        }
+        return largest;
+    }
+
+    /**
+     * Generates tensors having the single dimension "x", each with vectorSize cells
+     * labelled "0" through vectorSize-1 and holding random values in [0, 1).
+     *
+     * @param vectorCount the number of tensors to generate
+     * @param vectorSize the number of cells in each tensor
+     * @return a list of vectorCount newly built tensors
+     */
+    private List<Tensor> generateVectors(int vectorCount, int vectorSize) {
+        List<Tensor> tensors = new ArrayList<>();
+        for (int i = 0; i < vectorCount; i++) {
+            MapTensorBuilder builder = new MapTensorBuilder();
+            builder.dimension("x");
+            for (int j = 0; j < vectorSize; j++) {
+                builder.cell().label("x", String.valueOf(j)).value(random.nextDouble());
+            }
+            tensors.add(builder.build());
+        }
+        return tensors;
+    }
+
+    public static void main(String[] args) {
+        // Was: 144 ms
+        new TensorFunctionBenchmark().benchmark(100);
+    }
+
+}