From 494f5d0e133417880881b05c4fe4f08a265a7510 Mon Sep 17 00:00:00 2001 From: Arne Juul Date: Mon, 20 Mar 2023 10:41:16 +0000 Subject: add withTransformedFunctions() to TensorFunction API --- .../com/yahoo/tensor/functions/DynamicTensor.java | 24 ++++++++++++++++++++++ .../java/com/yahoo/tensor/functions/Slice.java | 17 +++++++++++++++ .../com/yahoo/tensor/functions/TensorFunction.java | 6 ++++++ 3 files changed, 47 insertions(+) (limited to 'vespajlib/src') diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java index 61d3acf6338..630eeb81d13 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java @@ -12,8 +12,10 @@ import com.yahoo.tensor.evaluation.TypeContext; import java.util.ArrayList; import java.util.List; +import java.util.LinkedHashMap; import java.util.Map; import java.util.Objects; +import java.util.function.Function; /** * A function which is a tensor whose values are computed by individual lambda functions on evaluation. @@ -82,6 +84,17 @@ public abstract class DynamicTensor extends PrimitiveTens return result; } + public TensorFunction withTransformedFunctions( + Function, ScalarFunction> transformer) + { + Map> transformedCells = new LinkedHashMap<>(); + for (var orig : cells.entrySet()) { + var transformed = transformer.apply(orig.getValue()); + transformedCells.put(orig.getKey(), transformed); + } + return new MappedDynamicTensor<>(type(), transformedCells); + } + @Override public Tensor evaluate(EvaluationContext context) { Tensor.Builder builder = Tensor.Builder.of(type()); @@ -134,6 +147,17 @@ public abstract class DynamicTensor extends PrimitiveTens return result; } + public TensorFunction withTransformedFunctions( + Function, ScalarFunction> transformer) + { + List> transformedCells = new ArrayList<>(); + for (var orig : cells) { + var transformed = transformer.apply(orig); + transformedCells.add(transformed); + } + return new IndexedDynamicTensor<>(type(), transformedCells); + } + @Override public Tensor evaluate(EvaluationContext context) { IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java index 87e24306031..066d75bcd9c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java @@ -16,6 +16,7 @@ import java.util.Iterator; import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.function.Function; import java.util.function.Predicate; import java.util.stream.Collectors; @@ -56,6 +57,22 @@ public class Slice extends PrimitiveTensorFunction withTransformedFunctions( + Function, ScalarFunction> transformer) + { + List> transformedAddress = new ArrayList<>(); + for (var orig : subspaceAddress) { + var idxFun = orig.index(); + if (idxFun.isPresent()) { + var transformed = transformer.apply(idxFun.get()); + transformedAddress.add(new DimensionValue(orig.dimension(), transformed)); + } else { + transformedAddress.add(orig); + } + } + return new Slice<>(argument, transformedAddress); + } + @Override public Slice withArguments(List> arguments) { if (arguments.size() != 1) diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java index 503f414d8eb..bf5eaeb6c2e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java @@ -10,6 +10,7 @@ import com.yahoo.tensor.evaluation.TypeContext; import java.util.List; import java.util.Optional; +import java.util.function.Function; /** * A representation of a tensor function which is able to be translated to a set of primitive @@ -72,4 +73,9 @@ public abstract class TensorFunction { @Override public abstract int hashCode(); + public TensorFunction withTransformedFunctions( + Function, ScalarFunction> transformer) + { + return this; + } } -- cgit v1.2.3