From 494f5d0e133417880881b05c4fe4f08a265a7510 Mon Sep 17 00:00:00 2001 From: Arne Juul Date: Mon, 20 Mar 2023 10:41:16 +0000 Subject: add withTransformedFunctions() to TensorFunction API --- vespajlib/abi-spec.json | 4 +++- .../com/yahoo/tensor/functions/DynamicTensor.java | 24 ++++++++++++++++++++++ .../java/com/yahoo/tensor/functions/Slice.java | 17 +++++++++++++++ .../com/yahoo/tensor/functions/TensorFunction.java | 6 ++++++ 4 files changed, 50 insertions(+), 1 deletion(-) diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index e45b13a6eb0..88872fef8a1 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -2765,6 +2765,7 @@ "public void (com.yahoo.tensor.functions.TensorFunction, java.util.List)", "public java.util.List arguments()", "public java.util.List selectorFunctions()", + "public com.yahoo.tensor.functions.TensorFunction withTransformedFunctions(java.util.function.Function)", "public com.yahoo.tensor.functions.Slice withArguments(java.util.List)", "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()", "public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)", @@ -2810,7 +2811,8 @@ "public abstract java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)", "public java.util.Optional asScalarFunction()", "public java.lang.String toString()", - "public abstract int hashCode()" + "public abstract int hashCode()", + "public com.yahoo.tensor.functions.TensorFunction withTransformedFunctions(java.util.function.Function)" ], "fields" : [ ] }, diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java index 61d3acf6338..630eeb81d13 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java @@ -12,8 +12,10 @@ import com.yahoo.tensor.evaluation.TypeContext; import java.util.ArrayList; import java.util.List; +import java.util.LinkedHashMap; import java.util.Map; import java.util.Objects; +import java.util.function.Function; /** * A function which is a tensor whose values are computed by individual lambda functions on evaluation. @@ -82,6 +84,17 @@ public abstract class DynamicTensor extends PrimitiveTens return result; } + public TensorFunction withTransformedFunctions( + Function, ScalarFunction> transformer) + { + Map> transformedCells = new LinkedHashMap<>(); + for (var orig : cells.entrySet()) { + var transformed = transformer.apply(orig.getValue()); + transformedCells.put(orig.getKey(), transformed); + } + return new MappedDynamicTensor<>(type(), transformedCells); + } + @Override public Tensor evaluate(EvaluationContext context) { Tensor.Builder builder = Tensor.Builder.of(type()); @@ -134,6 +147,17 @@ public abstract class DynamicTensor extends PrimitiveTens return result; } + public TensorFunction withTransformedFunctions( + Function, ScalarFunction> transformer) + { + List> transformedCells = new ArrayList<>(); + for (var orig : cells) { + var transformed = transformer.apply(orig); + transformedCells.add(transformed); + } + return new IndexedDynamicTensor<>(type(), transformedCells); + } + @Override public Tensor evaluate(EvaluationContext context) { IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java index 87e24306031..066d75bcd9c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java @@ -16,6 +16,7 @@ import java.util.Iterator; import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.function.Function; import java.util.function.Predicate; import java.util.stream.Collectors; @@ -56,6 +57,22 @@ public class Slice extends PrimitiveTensorFunction withTransformedFunctions( + Function, ScalarFunction> transformer) + { + List> transformedAddress = new ArrayList<>(); + for (var orig : subspaceAddress) { + var idxFun = orig.index(); + if (idxFun.isPresent()) { + var transformed = transformer.apply(idxFun.get()); + transformedAddress.add(new DimensionValue(orig.dimension(), transformed)); + } else { + transformedAddress.add(orig); + } + } + return new Slice<>(argument, transformedAddress); + } + @Override public Slice withArguments(List> arguments) { if (arguments.size() != 1) diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java index 503f414d8eb..bf5eaeb6c2e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java @@ -10,6 +10,7 @@ import com.yahoo.tensor.evaluation.TypeContext; import java.util.List; import java.util.Optional; +import java.util.function.Function; /** * A representation of a tensor function which is able to be translated to a set of primitive @@ -72,4 +73,9 @@ public abstract class TensorFunction { @Override public abstract int hashCode(); + public TensorFunction withTransformedFunctions( + Function, ScalarFunction> transformer) + { + return this; + } } -- cgit v1.2.3