From f4203c3cc571722f08ee65047437c1290ed63f69 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Tue, 26 Nov 2019 16:51:50 +0200 Subject: Allow bound functions in tensor generate --- .../java/com/yahoo/tensor/functions/DynamicTensor.java | 2 ++ .../main/java/com/yahoo/tensor/functions/Generate.java | 18 ++++++++++++------ .../com/yahoo/tensor/functions/ScalarFunction.java | 5 +++-- 3 files changed, 17 insertions(+), 8 deletions(-) (limited to 'vespajlib/src/main/java') diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java index b8b644f8b49..a75e49c6402 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java @@ -73,6 +73,7 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction { } @Override + @SuppressWarnings("unchecked") public Tensor evaluate(EvaluationContext context) { Tensor.Builder builder = Tensor.Builder.of(type()); for (var cell : cells.entrySet()) @@ -114,6 +115,7 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction { } @Override + @SuppressWarnings("unchecked") public Tensor evaluate(EvaluationContext context) { IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type()); for (int i = 0; i < cells.size(); i++) diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java index 28fc2c61426..52620814ecd 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java @@ -27,28 +27,33 @@ public class Generate extends PrimitiveTensorFunction { private final Function, Double> freeGenerator; private final ScalarFunction boundGenerator; + /** The same as Generate.free */ + public Generate(TensorType type, Function, Double> generator) { + this(type, Objects.requireNonNull(generator), null); + } + /** - * Creates a generated tensor + * Creates a generated tensor from a free function * * @param type the type of the tensor * @param generator the function generating values from a list of numbers specifying the indexes of the * tensor cell which will receive the value * @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound */ - public Generate(TensorType type, Function, Double> generator) { - this(type, Objects.requireNonNull(generator), null); + public static Generate free(TensorType type, Function, Double> generator) { + return new Generate(type, Objects.requireNonNull(generator), null); } /** - * Creates a generated tensor + * Creates a generated tensor from a bound function * * @param type the type of the tensor * @param generator the function generating values from a list of numbers specifying the indexes of the * tensor cell which will receive the value * @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound */ - public Generate(TensorType type, ScalarFunction generator) { - this(type, null, Objects.requireNonNull(generator)); + public static Generate bound(TensorType type, ScalarFunction generator) { + return new Generate(type, null, Objects.requireNonNull(generator)); } private Generate(TensorType type, Function, Double> freeGenerator, ScalarFunction boundGenerator) { @@ -127,6 +132,7 @@ public class Generate extends PrimitiveTensorFunction { this.context = context; } + @SuppressWarnings("unchecked") double apply(IndexedTensor.Indexes indexes) { if (freeGenerator != null) { return freeGenerator.apply(indexes.toList()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java index c6a244b64df..70e08af16b6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java @@ -2,6 +2,7 @@ package com.yahoo.tensor.functions; import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.function.Function; @@ -10,10 +11,10 @@ import java.util.function.Function; * * @author bratseth */ -public interface ScalarFunction extends Function, Double> { +public interface ScalarFunction extends Function, Double> { @Override - Double apply(EvaluationContext context); + Double apply(EvaluationContext context); default String toString(ToStringContext context) { return toString(); -- cgit v1.2.3