summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java53
1 files changed, 53 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
new file mode 100644
index 00000000000..ba34c0d9748
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
@@ -0,0 +1,53 @@
+package com.yahoo.tensor.functions;
+
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.CompositeTensorFunction;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.PrimitiveTensorFunction;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.TensorFunction;
+import com.yahoo.tensor.functions.ToStringContext;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+/**
+ * A tensor generator which returns a tensor of any dimension filled with random numbers between 0 and 1.
+ *
+ * @author bratseth
+ */
+public class Random extends CompositeTensorFunction {
+
+ private final TensorType type;
+
+ public Random(TensorType type) {
+ this.type = type;
+ }
+
+ @Override
+ public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
+
+ @Override
+ public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ if ( arguments.size() != 1)
+ throw new IllegalArgumentException("Random must have 0 arguments, got " + arguments.size());
+ return this;
+ }
+
+ @Override
+ public PrimitiveTensorFunction toPrimitive() {
+ return new Generate(type, ScalarFunctions.random());
+ }
+
+ @Override
+ public String toString(ToStringContext context) {
+ return "random(" + dimensionNames().collect(Collectors.joining(",")) + ")";
+ }
+
+ private Stream<String> dimensionNames() {
+ return type.dimensions().stream().map(TensorType.Dimension::toString);
+ }
+
+}