diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-11-27 15:58:06 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-11-27 15:58:06 +0200 |
commit | 77bb8f5117b7a0f78b2dc99a3937430339e4291d (patch) | |
tree | 9037b54f17e3175a8d11e1b43b55b71887f867a4 /vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java | |
parent | f4203c3cc571722f08ee65047437c1290ed63f69 (diff) |
Support index generating expressions in tensor value functions
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java | 28 |
1 files changed, 14 insertions, 14 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java index 52620814ecd..aaed607aaa1 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java @@ -19,13 +19,13 @@ import java.util.function.Function; * * @author bratseth */ -public class Generate extends PrimitiveTensorFunction { +public class Generate<NAMETYPE extends TypeContext.Name> extends PrimitiveTensorFunction<NAMETYPE> { private final TensorType type; // One of these are null private final Function<List<Long>, Double> freeGenerator; - private final ScalarFunction boundGenerator; + private final ScalarFunction<NAMETYPE> boundGenerator; /** The same as Generate.free */ public Generate(TensorType type, Function<List<Long>, Double> generator) { @@ -40,8 +40,8 @@ public class Generate extends PrimitiveTensorFunction { * tensor cell which will receive the value * @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound */ - public static Generate free(TensorType type, Function<List<Long>, Double> generator) { - return new Generate(type, Objects.requireNonNull(generator), null); + public static <NAMETYPE extends TypeContext.Name> Generate<NAMETYPE> free(TensorType type, Function<List<Long>, Double> generator) { + return new Generate<>(type, Objects.requireNonNull(generator), null); } /** @@ -52,11 +52,11 @@ public class Generate extends PrimitiveTensorFunction { * tensor cell which will receive the value * @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound */ - public static Generate bound(TensorType type, ScalarFunction generator) { - return new Generate(type, null, Objects.requireNonNull(generator)); + public static <NAMETYPE extends TypeContext.Name> Generate<NAMETYPE> bound(TensorType type, ScalarFunction<NAMETYPE> generator) { + return new Generate<>(type, null, Objects.requireNonNull(generator)); } - private Generate(TensorType type, Function<List<Long>, Double> freeGenerator, ScalarFunction boundGenerator) { + private Generate(TensorType type, Function<List<Long>, Double> freeGenerator, ScalarFunction<NAMETYPE> boundGenerator) { Objects.requireNonNull(type, "The argument tensor type cannot be null"); validateType(type); this.type = type; @@ -71,26 +71,26 @@ public class Generate extends PrimitiveTensorFunction { } @Override - public List<TensorFunction> arguments() { return Collections.emptyList(); } + public List<TensorFunction<NAMETYPE>> arguments() { return Collections.emptyList(); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 0) throw new IllegalArgumentException("Generate must have 0 arguments, got " + arguments.size()); return this; } @Override - public PrimitiveTensorFunction toPrimitive() { return this; } + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; } @Override - public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return type; } + public TensorType type(TypeContext<NAMETYPE> context) { return type; } @Override - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor.Builder builder = Tensor.Builder.of(type); IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(dimensionSizes(type)); - GenerateContext<NAMETYPE> generateContext = new GenerateContext<>(type, context); + GenerateContext generateContext = new GenerateContext(type, context); for (int i = 0; i < indexes.size(); i++) { indexes.next(); builder.cell(generateContext.apply(indexes), indexes.indexesForReading()); @@ -120,7 +120,7 @@ public class Generate extends PrimitiveTensorFunction { * This returns all the current index values as variables and falls back to delivering from the given * evaluation context. */ - private class GenerateContext<NAMETYPE extends TypeContext.Name> implements EvaluationContext<NAMETYPE> { + private class GenerateContext implements EvaluationContext<NAMETYPE> { private final TensorType type; private final EvaluationContext<NAMETYPE> context; |