summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2016-12-22 11:12:53 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2016-12-22 11:12:53 +0100
commit39a652ce439a42fb8db372c821c834d02c95b0f1 (patch)
treeb246abdbddfcf8fd84f2f74590aeb56743f2dfa3 /vespajlib
parented8ec5305f6838e31de94ef87ddd3a75390b59ed (diff)
Tensor generate functions
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java56
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java53
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java51
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java82
5 files changed, 243 insertions, 1 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index c1a24abd878..e99e7da7415 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -443,7 +443,7 @@ public class IndexedTensor implements Tensor {
public boolean equals(Object o) {
if (o == this) return true;
if ( ! ( o instanceof Map.Entry)) return false;
- Map.Entry other = (Map.Entry)o;
+ Map.Entry<?,?> other = (Map.Entry)o;
if ( ! this.getValue().equals(other.getValue())) return false;
if ( ! this.getKey().equals(other.getKey())) return false;
return true;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
new file mode 100644
index 00000000000..0bb92bc2a6f
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
@@ -0,0 +1,56 @@
+package com.yahoo.tensor.functions;
+
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.CompositeTensorFunction;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.PrimitiveTensorFunction;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.TensorFunction;
+import com.yahoo.tensor.functions.ToStringContext;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+/**
+ * A tensor generator which returns a tensor of any dimension filled with 1 in the diagonal and 0 elsewhere.
+ *
+ * @author bratseth
+ */
+public class Diag extends CompositeTensorFunction {
+
+ private final TensorType type;
+ private final Function<List<Integer>, Double> diagFunction;
+
+ public Diag(TensorType type) {
+ this.type = type;
+ this.diagFunction = ScalarFunctions.equalArguments(dimensionNames().collect(Collectors.toList()));
+ }
+
+ @Override
+ public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
+
+ @Override
+ public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ if ( arguments.size() != 1)
+ throw new IllegalArgumentException("Diag must have 0 arguments, got " + arguments.size());
+ return this;
+ }
+
+ @Override
+ public PrimitiveTensorFunction toPrimitive() {
+ return new Generate(type, diagFunction);
+ }
+
+ @Override
+ public String toString(ToStringContext context) {
+ return "diag(" + dimensionNames().collect(Collectors.joining(",")) + ")" + diagFunction;
+ }
+
+ private Stream<String> dimensionNames() {
+ return type.dimensions().stream().map(TensorType.Dimension::toString);
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
new file mode 100644
index 00000000000..ba34c0d9748
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
@@ -0,0 +1,53 @@
+package com.yahoo.tensor.functions;
+
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.CompositeTensorFunction;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.PrimitiveTensorFunction;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.TensorFunction;
+import com.yahoo.tensor.functions.ToStringContext;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+/**
+ * A tensor generator which returns a tensor of any dimension filled with random numbers between 0 and 1.
+ *
+ * @author bratseth
+ */
+public class Random extends CompositeTensorFunction {
+
+ private final TensorType type;
+
+ public Random(TensorType type) {
+ this.type = type;
+ }
+
+ @Override
+ public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
+
+ @Override
+ public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ if ( arguments.size() != 1)
+ throw new IllegalArgumentException("Random must have 0 arguments, got " + arguments.size());
+ return this;
+ }
+
+ @Override
+ public PrimitiveTensorFunction toPrimitive() {
+ return new Generate(type, ScalarFunctions.random());
+ }
+
+ @Override
+ public String toString(ToStringContext context) {
+ return "random(" + dimensionNames().collect(Collectors.joining(",")) + ")";
+ }
+
+ private Stream<String> dimensionNames() {
+ return type.dimensions().stream().map(TensorType.Dimension::toString);
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
new file mode 100644
index 00000000000..e18edd48127
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
@@ -0,0 +1,51 @@
+package com.yahoo.tensor.functions;
+
+import com.yahoo.tensor.TensorType;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+/**
+ * A tensor generator which returns a tensor of any dimension filled with the sum of the tensor
+ * indexes of each position.
+ *
+ * @author bratseth
+ */
+public class Range extends CompositeTensorFunction {
+
+ private final TensorType type;
+ private final Function<List<Integer>, Double> rangeFunction;
+
+ public Range(TensorType type) {
+ this.type = type;
+ this.rangeFunction = ScalarFunctions.sumArguments(dimensionNames().collect(Collectors.toList()));
+ }
+
+ @Override
+ public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
+
+ @Override
+ public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ if ( arguments.size() != 1)
+ throw new IllegalArgumentException("Range must have 0 arguments, got " + arguments.size());
+ return this;
+ }
+
+ @Override
+ public PrimitiveTensorFunction toPrimitive() {
+ return new Generate(type, rangeFunction);
+ }
+
+ @Override
+ public String toString(ToStringContext context) {
+ return "range(" + dimensionNames().collect(Collectors.joining(",")) + ")" + rangeFunction;
+ }
+
+ private Stream<String> dimensionNames() {
+ return type.dimensions().stream().map(TensorType.Dimension::toString);
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
index c1b1cb2243d..a0b60f53df3 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
@@ -1,9 +1,15 @@
package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
+import com.google.common.collect.ImmutableList;
+import java.util.List;
+import java.util.concurrent.ThreadLocalRandom;
import java.util.function.DoubleBinaryOperator;
+import java.util.function.DoubleSupplier;
import java.util.function.DoubleUnaryOperator;
+import java.util.function.Function;
+import java.util.stream.Collectors;
/**
* Factory of scalar Java functions.
@@ -21,6 +27,13 @@ public class ScalarFunctions {
public static DoubleUnaryOperator square() { return new Square(); }
public static DoubleUnaryOperator sqrt() { return new Sqrt(); }
public static DoubleUnaryOperator exp() { return new Exponent(); }
+ public static Function<List<Integer>, Double> random() { return new Random(); }
+ public static Function<List<Integer>, Double> equalArguments(List<String> argumentNames) {
+ return new EqualArguments(argumentNames);
+ }
+ public static Function<List<Integer>, Double> sumArguments(List<String> argumentNames) {
+ return new SumArguments(argumentNames);
+ }
public static class Addition implements DoubleBinaryOperator {
@@ -81,4 +94,73 @@ public class ScalarFunctions {
}
+ public static class Random implements Function<List<Integer>, Double> {
+
+ @Override
+ public Double apply(List<Integer> values) {
+ return ThreadLocalRandom.current().nextDouble();
+ }
+
+ @Override
+ public String toString() { return "random()"; }
+
+ }
+
+ public static class EqualArguments implements Function<List<Integer>, Double> {
+
+ private final ImmutableList<String> argumentNames;
+
+ private EqualArguments(List<String> argumentNames) {
+ this.argumentNames = ImmutableList.copyOf(argumentNames);
+ }
+
+ @Override
+ public Double apply(List<Integer> values) {
+ if (values.isEmpty()) return 1.0;
+ for (Integer value : values)
+ if ( ! value.equals(values.get(0)))
+ return 0.0;
+ return 1.0;
+ }
+
+ @Override
+ public String toString() {
+ if (argumentNames.size() == 0) return "(1)";
+ if (argumentNames.size() == 1) return "(1)";
+ if (argumentNames.size() == 2) return "(" + argumentNames.get(0) + "==" + argumentNames.get(1) + ")";
+
+ StringBuilder b = new StringBuilder("(");
+ for (int i = 0; i < argumentNames.size() -1; i++) {
+ b.append("(").append(argumentNames.get(i)).append("==").append(argumentNames.get(i+1)).append(")");
+ if ( i < argumentNames.size() -2)
+ b.append("*");
+ }
+ return b.toString();
+ }
+
+ }
+
+ public static class SumArguments implements Function<List<Integer>, Double> {
+
+ private final ImmutableList<String> argumentNames;
+
+ private SumArguments(List<String> argumentNames) {
+ this.argumentNames = ImmutableList.copyOf(argumentNames);
+ }
+
+ @Override
+ public Double apply(List<Integer> values) {
+ int sum = 0;
+ for (Integer value : values)
+ sum += value;
+ return (double)sum;
+ }
+
+ @Override
+ public String toString() {
+ return "(" + argumentNames.stream().collect(Collectors.joining("+")) + ")";
+ }
+
+ }
+
}