aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java24
1 files changed, 24 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
index 61d3acf6338..630eeb81d13 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
@@ -12,8 +12,10 @@ import com.yahoo.tensor.evaluation.TypeContext;
import java.util.ArrayList;
import java.util.List;
+import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;
+import java.util.function.Function;
/**
* A function which is a tensor whose values are computed by individual lambda functions on evaluation.
@@ -82,6 +84,17 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens
return result;
}
+ public TensorFunction<NAMETYPE> withTransformedFunctions(
+ Function<ScalarFunction<NAMETYPE>, ScalarFunction<NAMETYPE>> transformer)
+ {
+ Map<TensorAddress, ScalarFunction<NAMETYPE>> transformedCells = new LinkedHashMap<>();
+ for (var orig : cells.entrySet()) {
+ var transformed = transformer.apply(orig.getValue());
+ transformedCells.put(orig.getKey(), transformed);
+ }
+ return new MappedDynamicTensor<>(type(), transformedCells);
+ }
+
@Override
public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor.Builder builder = Tensor.Builder.of(type());
@@ -134,6 +147,17 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens
return result;
}
+ public TensorFunction<NAMETYPE> withTransformedFunctions(
+ Function<ScalarFunction<NAMETYPE>, ScalarFunction<NAMETYPE>> transformer)
+ {
+ List<ScalarFunction<NAMETYPE>> transformedCells = new ArrayList<>();
+ for (var orig : cells) {
+ var transformed = transformer.apply(orig);
+ transformedCells.add(transformed);
+ }
+ return new IndexedDynamicTensor<>(type(), transformedCells);
+ }
+
@Override
public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type());