diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-11-26 16:51:50 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-11-26 16:51:50 +0200 |
commit | f4203c3cc571722f08ee65047437c1290ed63f69 (patch) | |
tree | 7d06d17091a2e388e6771187a11cf4f4023a0c1e /vespajlib | |
parent | 316c941e90f39d2e9bc46f12b96ca0f87471d1bd (diff) |
Allow bound functions in tensor generate
Diffstat (limited to 'vespajlib')
5 files changed, 22 insertions, 11 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 15cfce09793..8cba1ccdef8 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1620,7 +1620,8 @@ ], "methods": [ "public void <init>(com.yahoo.tensor.TensorType, java.util.function.Function)", - "public void <init>(com.yahoo.tensor.TensorType, com.yahoo.tensor.functions.ScalarFunction)", + "public static com.yahoo.tensor.functions.Generate free(com.yahoo.tensor.TensorType, java.util.function.Function)", + "public static com.yahoo.tensor.functions.Generate bound(com.yahoo.tensor.TensorType, com.yahoo.tensor.functions.ScalarFunction)", "public java.util.List arguments()", "public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)", "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()", diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java index b8b644f8b49..a75e49c6402 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java @@ -73,6 +73,7 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction { } @Override + @SuppressWarnings("unchecked") public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor.Builder builder = Tensor.Builder.of(type()); for (var cell : cells.entrySet()) @@ -114,6 +115,7 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction { } @Override + @SuppressWarnings("unchecked") public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type()); for (int i = 0; i < cells.size(); i++) diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java index 28fc2c61426..52620814ecd 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java @@ -27,28 +27,33 @@ public class Generate extends PrimitiveTensorFunction { private final Function<List<Long>, Double> freeGenerator; private final ScalarFunction boundGenerator; + /** The same as Generate.free */ + public Generate(TensorType type, Function<List<Long>, Double> generator) { + this(type, Objects.requireNonNull(generator), null); + } + /** - * Creates a generated tensor + * Creates a generated tensor from a free function * * @param type the type of the tensor * @param generator the function generating values from a list of numbers specifying the indexes of the * tensor cell which will receive the value * @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound */ - public Generate(TensorType type, Function<List<Long>, Double> generator) { - this(type, Objects.requireNonNull(generator), null); + public static Generate free(TensorType type, Function<List<Long>, Double> generator) { + return new Generate(type, Objects.requireNonNull(generator), null); } /** - * Creates a generated tensor + * Creates a generated tensor from a bound function * * @param type the type of the tensor * @param generator the function generating values from a list of numbers specifying the indexes of the * tensor cell which will receive the value * @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound */ - public Generate(TensorType type, ScalarFunction generator) { - this(type, null, Objects.requireNonNull(generator)); + public static Generate bound(TensorType type, ScalarFunction generator) { + return new Generate(type, null, Objects.requireNonNull(generator)); } private Generate(TensorType type, Function<List<Long>, Double> freeGenerator, ScalarFunction boundGenerator) { @@ -127,6 +132,7 @@ public class Generate extends PrimitiveTensorFunction { this.context = context; } + @SuppressWarnings("unchecked") double apply(IndexedTensor.Indexes indexes) { if (freeGenerator != null) { return freeGenerator.apply(indexes.toList()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java index c6a244b64df..70e08af16b6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java @@ -2,6 +2,7 @@ package com.yahoo.tensor.functions; import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.function.Function; @@ -10,10 +11,10 @@ import java.util.function.Function; * * @author bratseth */ -public interface ScalarFunction extends Function<EvaluationContext<?>, Double> { +public interface ScalarFunction<NAMETYPE extends TypeContext.Name> extends Function<EvaluationContext<NAMETYPE>, Double> { @Override - Double apply(EvaluationContext<?> context); + Double apply(EvaluationContext<NAMETYPE> context); default String toString(ToStringContext context) { return toString(); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java index 925da9d3c89..e1ae7f13c48 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java @@ -5,6 +5,7 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.TypeContext; import org.junit.Test; import java.util.Collections; @@ -34,14 +35,14 @@ public class DynamicTensorTestCase { assertEquals("tensor(x{}):{{x:a}:5.0}", t2.toString()); } - private static class Constant implements ScalarFunction { + private static class Constant implements ScalarFunction<TypeContext.Name> { private final double value; public Constant(double value) { this.value = value; } @Override - public Double apply(EvaluationContext<?> evaluationContext) { return value; } + public Double apply(EvaluationContext<TypeContext.Name> evaluationContext) { return value; } @Override public String toString() { return String.valueOf(value); } |