aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java28
1 files changed, 14 insertions, 14 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
index 52620814ecd..aaed607aaa1 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
@@ -19,13 +19,13 @@ import java.util.function.Function;
*
* @author bratseth
*/
-public class Generate extends PrimitiveTensorFunction {
+public class Generate<NAMETYPE extends TypeContext.Name> extends PrimitiveTensorFunction<NAMETYPE> {
private final TensorType type;
// One of these are null
private final Function<List<Long>, Double> freeGenerator;
- private final ScalarFunction boundGenerator;
+ private final ScalarFunction<NAMETYPE> boundGenerator;
/** The same as Generate.free */
public Generate(TensorType type, Function<List<Long>, Double> generator) {
@@ -40,8 +40,8 @@ public class Generate extends PrimitiveTensorFunction {
* tensor cell which will receive the value
* @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound
*/
- public static Generate free(TensorType type, Function<List<Long>, Double> generator) {
- return new Generate(type, Objects.requireNonNull(generator), null);
+ public static <NAMETYPE extends TypeContext.Name> Generate<NAMETYPE> free(TensorType type, Function<List<Long>, Double> generator) {
+ return new Generate<>(type, Objects.requireNonNull(generator), null);
}
/**
@@ -52,11 +52,11 @@ public class Generate extends PrimitiveTensorFunction {
* tensor cell which will receive the value
* @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound
*/
- public static Generate bound(TensorType type, ScalarFunction generator) {
- return new Generate(type, null, Objects.requireNonNull(generator));
+ public static <NAMETYPE extends TypeContext.Name> Generate<NAMETYPE> bound(TensorType type, ScalarFunction<NAMETYPE> generator) {
+ return new Generate<>(type, null, Objects.requireNonNull(generator));
}
- private Generate(TensorType type, Function<List<Long>, Double> freeGenerator, ScalarFunction boundGenerator) {
+ private Generate(TensorType type, Function<List<Long>, Double> freeGenerator, ScalarFunction<NAMETYPE> boundGenerator) {
Objects.requireNonNull(type, "The argument tensor type cannot be null");
validateType(type);
this.type = type;
@@ -71,26 +71,26 @@ public class Generate extends PrimitiveTensorFunction {
}
@Override
- public List<TensorFunction> arguments() { return Collections.emptyList(); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return Collections.emptyList(); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if ( arguments.size() != 0)
throw new IllegalArgumentException("Generate must have 0 arguments, got " + arguments.size());
return this;
}
@Override
- public PrimitiveTensorFunction toPrimitive() { return this; }
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; }
@Override
- public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return type; }
+ public TensorType type(TypeContext<NAMETYPE> context) { return type; }
@Override
- public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor.Builder builder = Tensor.Builder.of(type);
IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(dimensionSizes(type));
- GenerateContext<NAMETYPE> generateContext = new GenerateContext<>(type, context);
+ GenerateContext generateContext = new GenerateContext(type, context);
for (int i = 0; i < indexes.size(); i++) {
indexes.next();
builder.cell(generateContext.apply(indexes), indexes.indexesForReading());
@@ -120,7 +120,7 @@ public class Generate extends PrimitiveTensorFunction {
* This returns all the current index values as variables and falls back to delivering from the given
* evaluation context.
*/
- private class GenerateContext<NAMETYPE extends TypeContext.Name> implements EvaluationContext<NAMETYPE> {
+ private class GenerateContext implements EvaluationContext<NAMETYPE> {
private final TensorType type;
private final EvaluationContext<NAMETYPE> context;