diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-11-26 14:16:52 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-11-26 14:16:52 +0200 |
commit | 316c941e90f39d2e9bc46f12b96ca0f87471d1bd (patch) | |
tree | 63af0d60f937dda39bbe55793cc1a8253e89f134 /vespajlib | |
parent | b85ad56773d5bb675dd0c7fc437bcf130cb8a15d (diff) |
Allow bound functions in tensor Generate
Diffstat (limited to 'vespajlib')
-rw-r--r-- | vespajlib/abi-spec.json | 1 | ||||
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java | 94 |
2 files changed, 90 insertions, 5 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 307d5df3d6c..15cfce09793 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1620,6 +1620,7 @@ ], "methods": [ "public void <init>(com.yahoo.tensor.TensorType, java.util.function.Function)", + "public void <init>(com.yahoo.tensor.TensorType, com.yahoo.tensor.functions.ScalarFunction)", "public java.util.List arguments()", "public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)", "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()", diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java index 83cba3479e2..28fc2c61426 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java @@ -11,6 +11,7 @@ import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.List; import java.util.Objects; +import java.util.Optional; import java.util.function.Function; /** @@ -21,7 +22,10 @@ import java.util.function.Function; public class Generate extends PrimitiveTensorFunction { private final TensorType type; - private final Function<List<Long>, Double> generator; + + // One of these are null + private final Function<List<Long>, Double> freeGenerator; + private final ScalarFunction boundGenerator; /** * Creates a generated tensor @@ -32,11 +36,27 @@ public class Generate extends PrimitiveTensorFunction { * @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound */ public Generate(TensorType type, Function<List<Long>, Double> generator) { + this(type, Objects.requireNonNull(generator), null); + } + + /** + * Creates a generated tensor + * + * @param type the type of the tensor + * @param generator the function generating values from a list of numbers specifying the indexes of the + * tensor cell which will receive the value + * @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound + */ + public Generate(TensorType type, ScalarFunction generator) { + this(type, null, Objects.requireNonNull(generator)); + } + + private Generate(TensorType type, Function<List<Long>, Double> freeGenerator, ScalarFunction boundGenerator) { Objects.requireNonNull(type, "The argument tensor type cannot be null"); - Objects.requireNonNull(generator, "The argument function cannot be null"); validateType(type); this.type = type; - this.generator = generator; + this.freeGenerator = freeGenerator; + this.boundGenerator = boundGenerator; } private void validateType(TensorType type) { @@ -65,9 +85,10 @@ public class Generate extends PrimitiveTensorFunction { public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor.Builder builder = Tensor.Builder.of(type); IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(dimensionSizes(type)); + GenerateContext<NAMETYPE> generateContext = new GenerateContext<>(type, context); for (int i = 0; i < indexes.size(); i++) { indexes.next(); - builder.cell(generator.apply(indexes.toList()), indexes.indexesForReading()); + builder.cell(generateContext.apply(indexes), indexes.indexesForReading()); } return builder.build(); } @@ -80,6 +101,69 @@ public class Generate extends PrimitiveTensorFunction { } @Override - public String toString(ToStringContext context) { return type + "(" + generator + ")"; } + public String toString(ToStringContext context) { return type + "(" + generatorToString(context) + ")"; } + + private String generatorToString(ToStringContext context) { + if (freeGenerator != null) + return freeGenerator.toString(); + else + return boundGenerator.toString(context); + } + + /** + * A context for generating all the values of a tensor produced by evaluating Generate. + * This returns all the current index values as variables and falls back to delivering from the given + * evaluation context. + */ + private class GenerateContext<NAMETYPE extends TypeContext.Name> implements EvaluationContext<NAMETYPE> { + + private final TensorType type; + private final EvaluationContext<NAMETYPE> context; + + private IndexedTensor.Indexes indexes; + + GenerateContext(TensorType type, EvaluationContext<NAMETYPE> context) { + this.type = type; + this.context = context; + } + + double apply(IndexedTensor.Indexes indexes) { + if (freeGenerator != null) { + return freeGenerator.apply(indexes.toList()); + } + else { + this.indexes = indexes; + return boundGenerator.apply(this); + } + } + + @Override + public Tensor getTensor(String name) { + Optional<Integer> index = type.indexOfDimension(name); + if (index.isPresent()) // this is the name of a dimension + return Tensor.from(indexes.indexesForReading()[index.get()]); + else + return context.getTensor(name); + } + + @Override + public TensorType getType(NAMETYPE name) { + Optional<Integer> index = type.indexOfDimension(name.name()); + if (index.isPresent()) // this is the name of a dimension + return TensorType.empty; + else + return context.getType(name); + } + + @Override + public TensorType getType(String name) { + Optional<Integer> index = type.indexOfDimension(name); + if (index.isPresent()) // this is the name of a dimension + return TensorType.empty; + else + return context.getType(name); + } + + } } |