aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-11-26 16:51:50 +0200
committerJon Bratseth <bratseth@verizonmedia.com>2019-11-26 16:51:50 +0200
commitf4203c3cc571722f08ee65047437c1290ed63f69 (patch)
tree7d06d17091a2e388e6771187a11cf4f4023a0c1e /vespajlib/src/main/java
parent316c941e90f39d2e9bc46f12b96ca0f87471d1bd (diff)
Allow bound functions in tensor generate
Diffstat (limited to 'vespajlib/src/main/java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java18
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java5
3 files changed, 17 insertions, 8 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
index b8b644f8b49..a75e49c6402 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
@@ -73,6 +73,7 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction {
}
@Override
+ @SuppressWarnings("unchecked")
public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor.Builder builder = Tensor.Builder.of(type());
for (var cell : cells.entrySet())
@@ -114,6 +115,7 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction {
}
@Override
+ @SuppressWarnings("unchecked")
public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type());
for (int i = 0; i < cells.size(); i++)
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
index 28fc2c61426..52620814ecd 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
@@ -27,28 +27,33 @@ public class Generate extends PrimitiveTensorFunction {
private final Function<List<Long>, Double> freeGenerator;
private final ScalarFunction boundGenerator;
+ /** The same as Generate.free */
+ public Generate(TensorType type, Function<List<Long>, Double> generator) {
+ this(type, Objects.requireNonNull(generator), null);
+ }
+
/**
- * Creates a generated tensor
+ * Creates a generated tensor from a free function
*
* @param type the type of the tensor
* @param generator the function generating values from a list of numbers specifying the indexes of the
* tensor cell which will receive the value
* @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound
*/
- public Generate(TensorType type, Function<List<Long>, Double> generator) {
- this(type, Objects.requireNonNull(generator), null);
+ public static Generate free(TensorType type, Function<List<Long>, Double> generator) {
+ return new Generate(type, Objects.requireNonNull(generator), null);
}
/**
- * Creates a generated tensor
+ * Creates a generated tensor from a bound function
*
* @param type the type of the tensor
* @param generator the function generating values from a list of numbers specifying the indexes of the
* tensor cell which will receive the value
* @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound
*/
- public Generate(TensorType type, ScalarFunction generator) {
- this(type, null, Objects.requireNonNull(generator));
+ public static Generate bound(TensorType type, ScalarFunction generator) {
+ return new Generate(type, null, Objects.requireNonNull(generator));
}
private Generate(TensorType type, Function<List<Long>, Double> freeGenerator, ScalarFunction boundGenerator) {
@@ -127,6 +132,7 @@ public class Generate extends PrimitiveTensorFunction {
this.context = context;
}
+ @SuppressWarnings("unchecked")
double apply(IndexedTensor.Indexes indexes) {
if (freeGenerator != null) {
return freeGenerator.apply(indexes.toList());
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java
index c6a244b64df..70e08af16b6 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java
@@ -2,6 +2,7 @@
package com.yahoo.tensor.functions;
import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.function.Function;
@@ -10,10 +11,10 @@ import java.util.function.Function;
*
* @author bratseth
*/
-public interface ScalarFunction extends Function<EvaluationContext<?>, Double> {
+public interface ScalarFunction<NAMETYPE extends TypeContext.Name> extends Function<EvaluationContext<NAMETYPE>, Double> {
@Override
- Double apply(EvaluationContext<?> context);
+ Double apply(EvaluationContext<NAMETYPE> context);
default String toString(ToStringContext context) {
return toString();