diff options
author | Arne Juul <arnej@yahooinc.com> | 2023-03-20 10:41:16 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2023-03-20 12:22:52 +0000 |
commit | 494f5d0e133417880881b05c4fe4f08a265a7510 (patch) | |
tree | e42ac6b5231ef5ad6182de3309b7d44e42d69fb5 | |
parent | 829f4a279b817b01144c5896de1a5c671804857d (diff) |
add withTransformedFunctions() to TensorFunction API
4 files changed, 50 insertions, 1 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index e45b13a6eb0..88872fef8a1 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -2765,6 +2765,7 @@ "public void <init>(com.yahoo.tensor.functions.TensorFunction, java.util.List)", "public java.util.List arguments()", "public java.util.List selectorFunctions()", + "public com.yahoo.tensor.functions.TensorFunction withTransformedFunctions(java.util.function.Function)", "public com.yahoo.tensor.functions.Slice withArguments(java.util.List)", "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()", "public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)", @@ -2810,7 +2811,8 @@ "public abstract java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)", "public java.util.Optional asScalarFunction()", "public java.lang.String toString()", - "public abstract int hashCode()" + "public abstract int hashCode()", + "public com.yahoo.tensor.functions.TensorFunction withTransformedFunctions(java.util.function.Function)" ], "fields" : [ ] }, diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java index 61d3acf6338..630eeb81d13 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java @@ -12,8 +12,10 @@ import com.yahoo.tensor.evaluation.TypeContext; import java.util.ArrayList; import java.util.List; +import java.util.LinkedHashMap; import java.util.Map; import java.util.Objects; +import java.util.function.Function; /** * A function which is a tensor whose values are computed by individual lambda functions on evaluation. @@ -82,6 +84,17 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens return result; } + public TensorFunction<NAMETYPE> withTransformedFunctions( + Function<ScalarFunction<NAMETYPE>, ScalarFunction<NAMETYPE>> transformer) + { + Map<TensorAddress, ScalarFunction<NAMETYPE>> transformedCells = new LinkedHashMap<>(); + for (var orig : cells.entrySet()) { + var transformed = transformer.apply(orig.getValue()); + transformedCells.put(orig.getKey(), transformed); + } + return new MappedDynamicTensor<>(type(), transformedCells); + } + @Override public Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor.Builder builder = Tensor.Builder.of(type()); @@ -134,6 +147,17 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens return result; } + public TensorFunction<NAMETYPE> withTransformedFunctions( + Function<ScalarFunction<NAMETYPE>, ScalarFunction<NAMETYPE>> transformer) + { + List<ScalarFunction<NAMETYPE>> transformedCells = new ArrayList<>(); + for (var orig : cells) { + var transformed = transformer.apply(orig); + transformedCells.add(transformed); + } + return new IndexedDynamicTensor<>(type(), transformedCells); + } + @Override public Tensor evaluate(EvaluationContext<NAMETYPE> context) { IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java index 87e24306031..066d75bcd9c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java @@ -16,6 +16,7 @@ import java.util.Iterator; import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.function.Function; import java.util.function.Predicate; import java.util.stream.Collectors; @@ -56,6 +57,22 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY return result; } + public TensorFunction<NAMETYPE> withTransformedFunctions( + Function<ScalarFunction<NAMETYPE>, ScalarFunction<NAMETYPE>> transformer) + { + List<DimensionValue<NAMETYPE>> transformedAddress = new ArrayList<>(); + for (var orig : subspaceAddress) { + var idxFun = orig.index(); + if (idxFun.isPresent()) { + var transformed = transformer.apply(idxFun.get()); + transformedAddress.add(new DimensionValue<NAMETYPE>(orig.dimension(), transformed)); + } else { + transformedAddress.add(orig); + } + } + return new Slice<>(argument, transformedAddress); + } + @Override public Slice<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if (arguments.size() != 1) diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java index 503f414d8eb..bf5eaeb6c2e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java @@ -10,6 +10,7 @@ import com.yahoo.tensor.evaluation.TypeContext; import java.util.List; import java.util.Optional; +import java.util.function.Function; /** * A representation of a tensor function which is able to be translated to a set of primitive @@ -72,4 +73,9 @@ public abstract class TensorFunction<NAMETYPE extends Name> { @Override public abstract int hashCode(); + public TensorFunction<NAMETYPE> withTransformedFunctions( + Function<ScalarFunction<NAMETYPE>, ScalarFunction<NAMETYPE>> transformer) + { + return this; + } } |