diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-11-26 16:51:50 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-11-26 16:51:50 +0200 |
commit | f4203c3cc571722f08ee65047437c1290ed63f69 (patch) | |
tree | 7d06d17091a2e388e6771187a11cf4f4023a0c1e /vespajlib/src/main/java | |
parent | 316c941e90f39d2e9bc46f12b96ca0f87471d1bd (diff) |
Allow bound functions in tensor generate
Diffstat (limited to 'vespajlib/src/main/java')
3 files changed, 17 insertions, 8 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java index b8b644f8b49..a75e49c6402 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java @@ -73,6 +73,7 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction { } @Override + @SuppressWarnings("unchecked") public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor.Builder builder = Tensor.Builder.of(type()); for (var cell : cells.entrySet()) @@ -114,6 +115,7 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction { } @Override + @SuppressWarnings("unchecked") public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type()); for (int i = 0; i < cells.size(); i++) diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java index 28fc2c61426..52620814ecd 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java @@ -27,28 +27,33 @@ public class Generate extends PrimitiveTensorFunction { private final Function<List<Long>, Double> freeGenerator; private final ScalarFunction boundGenerator; + /** The same as Generate.free */ + public Generate(TensorType type, Function<List<Long>, Double> generator) { + this(type, Objects.requireNonNull(generator), null); + } + /** - * Creates a generated tensor + * Creates a generated tensor from a free function * * @param type the type of the tensor * @param generator the function generating values from a list of numbers specifying the indexes of the * tensor cell which will receive the value * @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound */ - public Generate(TensorType type, Function<List<Long>, Double> generator) { - this(type, Objects.requireNonNull(generator), null); + public static Generate free(TensorType type, Function<List<Long>, Double> generator) { + return new Generate(type, Objects.requireNonNull(generator), null); } /** - * Creates a generated tensor + * Creates a generated tensor from a bound function * * @param type the type of the tensor * @param generator the function generating values from a list of numbers specifying the indexes of the * tensor cell which will receive the value * @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound */ - public Generate(TensorType type, ScalarFunction generator) { - this(type, null, Objects.requireNonNull(generator)); + public static Generate bound(TensorType type, ScalarFunction generator) { + return new Generate(type, null, Objects.requireNonNull(generator)); } private Generate(TensorType type, Function<List<Long>, Double> freeGenerator, ScalarFunction boundGenerator) { @@ -127,6 +132,7 @@ public class Generate extends PrimitiveTensorFunction { this.context = context; } + @SuppressWarnings("unchecked") double apply(IndexedTensor.Indexes indexes) { if (freeGenerator != null) { return freeGenerator.apply(indexes.toList()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java index c6a244b64df..70e08af16b6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java @@ -2,6 +2,7 @@ package com.yahoo.tensor.functions; import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.function.Function; @@ -10,10 +11,10 @@ import java.util.function.Function; * * @author bratseth */ -public interface ScalarFunction extends Function<EvaluationContext<?>, Double> { +public interface ScalarFunction<NAMETYPE extends TypeContext.Name> extends Function<EvaluationContext<NAMETYPE>, Double> { @Override - Double apply(EvaluationContext<?> context); + Double apply(EvaluationContext<NAMETYPE> context); default String toString(ToStringContext context) { return toString(); |