summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-11-26 16:51:50 +0200
committerJon Bratseth <bratseth@verizonmedia.com>2019-11-26 16:51:50 +0200
commitf4203c3cc571722f08ee65047437c1290ed63f69 (patch)
tree7d06d17091a2e388e6771187a11cf4f4023a0c1e /vespajlib
parent316c941e90f39d2e9bc46f12b96ca0f87471d1bd (diff)
Allow bound functions in tensor generate
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/abi-spec.json3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java18
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java5
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java5
5 files changed, 22 insertions, 11 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 15cfce09793..8cba1ccdef8 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -1620,7 +1620,8 @@
],
"methods": [
"public void <init>(com.yahoo.tensor.TensorType, java.util.function.Function)",
- "public void <init>(com.yahoo.tensor.TensorType, com.yahoo.tensor.functions.ScalarFunction)",
+ "public static com.yahoo.tensor.functions.Generate free(com.yahoo.tensor.TensorType, java.util.function.Function)",
+ "public static com.yahoo.tensor.functions.Generate bound(com.yahoo.tensor.TensorType, com.yahoo.tensor.functions.ScalarFunction)",
"public java.util.List arguments()",
"public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)",
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
index b8b644f8b49..a75e49c6402 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
@@ -73,6 +73,7 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction {
}
@Override
+ @SuppressWarnings("unchecked")
public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor.Builder builder = Tensor.Builder.of(type());
for (var cell : cells.entrySet())
@@ -114,6 +115,7 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction {
}
@Override
+ @SuppressWarnings("unchecked")
public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type());
for (int i = 0; i < cells.size(); i++)
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
index 28fc2c61426..52620814ecd 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
@@ -27,28 +27,33 @@ public class Generate extends PrimitiveTensorFunction {
private final Function<List<Long>, Double> freeGenerator;
private final ScalarFunction boundGenerator;
+ /** The same as Generate.free */
+ public Generate(TensorType type, Function<List<Long>, Double> generator) {
+ this(type, Objects.requireNonNull(generator), null);
+ }
+
/**
- * Creates a generated tensor
+ * Creates a generated tensor from a free function
*
* @param type the type of the tensor
* @param generator the function generating values from a list of numbers specifying the indexes of the
* tensor cell which will receive the value
* @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound
*/
- public Generate(TensorType type, Function<List<Long>, Double> generator) {
- this(type, Objects.requireNonNull(generator), null);
+ public static Generate free(TensorType type, Function<List<Long>, Double> generator) {
+ return new Generate(type, Objects.requireNonNull(generator), null);
}
/**
- * Creates a generated tensor
+ * Creates a generated tensor from a bound function
*
* @param type the type of the tensor
* @param generator the function generating values from a list of numbers specifying the indexes of the
* tensor cell which will receive the value
* @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound
*/
- public Generate(TensorType type, ScalarFunction generator) {
- this(type, null, Objects.requireNonNull(generator));
+ public static Generate bound(TensorType type, ScalarFunction generator) {
+ return new Generate(type, null, Objects.requireNonNull(generator));
}
private Generate(TensorType type, Function<List<Long>, Double> freeGenerator, ScalarFunction boundGenerator) {
@@ -127,6 +132,7 @@ public class Generate extends PrimitiveTensorFunction {
this.context = context;
}
+ @SuppressWarnings("unchecked")
double apply(IndexedTensor.Indexes indexes) {
if (freeGenerator != null) {
return freeGenerator.apply(indexes.toList());
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java
index c6a244b64df..70e08af16b6 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java
@@ -2,6 +2,7 @@
package com.yahoo.tensor.functions;
import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.function.Function;
@@ -10,10 +11,10 @@ import java.util.function.Function;
*
* @author bratseth
*/
-public interface ScalarFunction extends Function<EvaluationContext<?>, Double> {
+public interface ScalarFunction<NAMETYPE extends TypeContext.Name> extends Function<EvaluationContext<NAMETYPE>, Double> {
@Override
- Double apply(EvaluationContext<?> context);
+ Double apply(EvaluationContext<NAMETYPE> context);
default String toString(ToStringContext context) {
return toString();
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java
index 925da9d3c89..e1ae7f13c48 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java
@@ -5,6 +5,7 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.TypeContext;
import org.junit.Test;
import java.util.Collections;
@@ -34,14 +35,14 @@ public class DynamicTensorTestCase {
assertEquals("tensor(x{}):{{x:a}:5.0}", t2.toString());
}
- private static class Constant implements ScalarFunction {
+ private static class Constant implements ScalarFunction<TypeContext.Name> {
private final double value;
public Constant(double value) { this.value = value; }
@Override
- public Double apply(EvaluationContext<?> evaluationContext) { return value; }
+ public Double apply(EvaluationContext<TypeContext.Name> evaluationContext) { return value; }
@Override
public String toString() { return String.valueOf(value); }