summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-11-26 14:16:52 +0200
committerJon Bratseth <bratseth@verizonmedia.com>2019-11-26 14:16:52 +0200
commit316c941e90f39d2e9bc46f12b96ca0f87471d1bd (patch)
tree63af0d60f937dda39bbe55793cc1a8253e89f134 /vespajlib
parentb85ad56773d5bb675dd0c7fc437bcf130cb8a15d (diff)
Allow bound functions in tensor Generate
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/abi-spec.json1
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java94
2 files changed, 90 insertions, 5 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 307d5df3d6c..15cfce09793 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -1620,6 +1620,7 @@
],
"methods": [
"public void <init>(com.yahoo.tensor.TensorType, java.util.function.Function)",
+ "public void <init>(com.yahoo.tensor.TensorType, com.yahoo.tensor.functions.ScalarFunction)",
"public java.util.List arguments()",
"public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)",
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
index 83cba3479e2..28fc2c61426 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
@@ -11,6 +11,7 @@ import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
+import java.util.Optional;
import java.util.function.Function;
/**
@@ -21,7 +22,10 @@ import java.util.function.Function;
public class Generate extends PrimitiveTensorFunction {
private final TensorType type;
- private final Function<List<Long>, Double> generator;
+
+ // One of these are null
+ private final Function<List<Long>, Double> freeGenerator;
+ private final ScalarFunction boundGenerator;
/**
* Creates a generated tensor
@@ -32,11 +36,27 @@ public class Generate extends PrimitiveTensorFunction {
* @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound
*/
public Generate(TensorType type, Function<List<Long>, Double> generator) {
+ this(type, Objects.requireNonNull(generator), null);
+ }
+
+ /**
+ * Creates a generated tensor
+ *
+ * @param type the type of the tensor
+ * @param generator the function generating values from a list of numbers specifying the indexes of the
+ * tensor cell which will receive the value
+ * @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound
+ */
+ public Generate(TensorType type, ScalarFunction generator) {
+ this(type, null, Objects.requireNonNull(generator));
+ }
+
+ private Generate(TensorType type, Function<List<Long>, Double> freeGenerator, ScalarFunction boundGenerator) {
Objects.requireNonNull(type, "The argument tensor type cannot be null");
- Objects.requireNonNull(generator, "The argument function cannot be null");
validateType(type);
this.type = type;
- this.generator = generator;
+ this.freeGenerator = freeGenerator;
+ this.boundGenerator = boundGenerator;
}
private void validateType(TensorType type) {
@@ -65,9 +85,10 @@ public class Generate extends PrimitiveTensorFunction {
public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor.Builder builder = Tensor.Builder.of(type);
IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(dimensionSizes(type));
+ GenerateContext<NAMETYPE> generateContext = new GenerateContext<>(type, context);
for (int i = 0; i < indexes.size(); i++) {
indexes.next();
- builder.cell(generator.apply(indexes.toList()), indexes.indexesForReading());
+ builder.cell(generateContext.apply(indexes), indexes.indexesForReading());
}
return builder.build();
}
@@ -80,6 +101,69 @@ public class Generate extends PrimitiveTensorFunction {
}
@Override
- public String toString(ToStringContext context) { return type + "(" + generator + ")"; }
+ public String toString(ToStringContext context) { return type + "(" + generatorToString(context) + ")"; }
+
+ private String generatorToString(ToStringContext context) {
+ if (freeGenerator != null)
+ return freeGenerator.toString();
+ else
+ return boundGenerator.toString(context);
+ }
+
+ /**
+ * A context for generating all the values of a tensor produced by evaluating Generate.
+ * This returns all the current index values as variables and falls back to delivering from the given
+ * evaluation context.
+ */
+ private class GenerateContext<NAMETYPE extends TypeContext.Name> implements EvaluationContext<NAMETYPE> {
+
+ private final TensorType type;
+ private final EvaluationContext<NAMETYPE> context;
+
+ private IndexedTensor.Indexes indexes;
+
+ GenerateContext(TensorType type, EvaluationContext<NAMETYPE> context) {
+ this.type = type;
+ this.context = context;
+ }
+
+ double apply(IndexedTensor.Indexes indexes) {
+ if (freeGenerator != null) {
+ return freeGenerator.apply(indexes.toList());
+ }
+ else {
+ this.indexes = indexes;
+ return boundGenerator.apply(this);
+ }
+ }
+
+ @Override
+ public Tensor getTensor(String name) {
+ Optional<Integer> index = type.indexOfDimension(name);
+ if (index.isPresent()) // this is the name of a dimension
+ return Tensor.from(indexes.indexesForReading()[index.get()]);
+ else
+ return context.getTensor(name);
+ }
+
+ @Override
+ public TensorType getType(NAMETYPE name) {
+ Optional<Integer> index = type.indexOfDimension(name.name());
+ if (index.isPresent()) // this is the name of a dimension
+ return TensorType.empty;
+ else
+ return context.getType(name);
+ }
+
+ @Override
+ public TensorType getType(String name) {
+ Optional<Integer> index = type.indexOfDimension(name);
+ if (index.isPresent()) // this is the name of a dimension
+ return TensorType.empty;
+ else
+ return context.getType(name);
+ }
+
+ }
}