diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java | 53 |
1 files changed, 53 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java new file mode 100644 index 00000000000..ba34c0d9748 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java @@ -0,0 +1,53 @@ +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.CompositeTensorFunction; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.PrimitiveTensorFunction; +import com.yahoo.tensor.functions.ScalarFunctions; +import com.yahoo.tensor.functions.TensorFunction; +import com.yahoo.tensor.functions.ToStringContext; + +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * A tensor generator which returns a tensor of any dimension filled with random numbers between 0 and 1. + * + * @author bratseth + */ +public class Random extends CompositeTensorFunction { + + private final TensorType type; + + public Random(TensorType type) { + this.type = type; + } + + @Override + public List<TensorFunction> functionArguments() { return Collections.emptyList(); } + + @Override + public TensorFunction replaceArguments(List<TensorFunction> arguments) { + if ( arguments.size() != 1) + throw new IllegalArgumentException("Random must have 0 arguments, got " + arguments.size()); + return this; + } + + @Override + public PrimitiveTensorFunction toPrimitive() { + return new Generate(type, ScalarFunctions.random()); + } + + @Override + public String toString(ToStringContext context) { + return "random(" + dimensionNames().collect(Collectors.joining(",")) + ")"; + } + + private Stream<String> dimensionNames() { + return type.dimensions().stream().map(TensorType.Dimension::toString); + } + +} |