summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2023-03-20 10:41:16 +0000
committerArne Juul <arnej@yahooinc.com>2023-03-20 12:22:52 +0000
commit494f5d0e133417880881b05c4fe4f08a265a7510 (patch)
treee42ac6b5231ef5ad6182de3309b7d44e42d69fb5
parent829f4a279b817b01144c5896de1a5c671804857d (diff)
add withTransformedFunctions() to TensorFunction API
-rw-r--r--vespajlib/abi-spec.json4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java24
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java17
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java6
4 files changed, 50 insertions, 1 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index e45b13a6eb0..88872fef8a1 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -2765,6 +2765,7 @@
"public void <init>(com.yahoo.tensor.functions.TensorFunction, java.util.List)",
"public java.util.List arguments()",
"public java.util.List selectorFunctions()",
+ "public com.yahoo.tensor.functions.TensorFunction withTransformedFunctions(java.util.function.Function)",
"public com.yahoo.tensor.functions.Slice withArguments(java.util.List)",
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
"public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)",
@@ -2810,7 +2811,8 @@
"public abstract java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
"public java.util.Optional asScalarFunction()",
"public java.lang.String toString()",
- "public abstract int hashCode()"
+ "public abstract int hashCode()",
+ "public com.yahoo.tensor.functions.TensorFunction withTransformedFunctions(java.util.function.Function)"
],
"fields" : [ ]
},
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
index 61d3acf6338..630eeb81d13 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
@@ -12,8 +12,10 @@ import com.yahoo.tensor.evaluation.TypeContext;
import java.util.ArrayList;
import java.util.List;
+import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;
+import java.util.function.Function;
/**
* A function which is a tensor whose values are computed by individual lambda functions on evaluation.
@@ -82,6 +84,17 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens
return result;
}
+ public TensorFunction<NAMETYPE> withTransformedFunctions(
+ Function<ScalarFunction<NAMETYPE>, ScalarFunction<NAMETYPE>> transformer)
+ {
+ Map<TensorAddress, ScalarFunction<NAMETYPE>> transformedCells = new LinkedHashMap<>();
+ for (var orig : cells.entrySet()) {
+ var transformed = transformer.apply(orig.getValue());
+ transformedCells.put(orig.getKey(), transformed);
+ }
+ return new MappedDynamicTensor<>(type(), transformedCells);
+ }
+
@Override
public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor.Builder builder = Tensor.Builder.of(type());
@@ -134,6 +147,17 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens
return result;
}
+ public TensorFunction<NAMETYPE> withTransformedFunctions(
+ Function<ScalarFunction<NAMETYPE>, ScalarFunction<NAMETYPE>> transformer)
+ {
+ List<ScalarFunction<NAMETYPE>> transformedCells = new ArrayList<>();
+ for (var orig : cells) {
+ var transformed = transformer.apply(orig);
+ transformedCells.add(transformed);
+ }
+ return new IndexedDynamicTensor<>(type(), transformedCells);
+ }
+
@Override
public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type());
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
index 87e24306031..066d75bcd9c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
@@ -16,6 +16,7 @@ import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
+import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
@@ -56,6 +57,22 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY
return result;
}
+ public TensorFunction<NAMETYPE> withTransformedFunctions(
+ Function<ScalarFunction<NAMETYPE>, ScalarFunction<NAMETYPE>> transformer)
+ {
+ List<DimensionValue<NAMETYPE>> transformedAddress = new ArrayList<>();
+ for (var orig : subspaceAddress) {
+ var idxFun = orig.index();
+ if (idxFun.isPresent()) {
+ var transformed = transformer.apply(idxFun.get());
+ transformedAddress.add(new DimensionValue<NAMETYPE>(orig.dimension(), transformed));
+ } else {
+ transformedAddress.add(orig);
+ }
+ }
+ return new Slice<>(argument, transformedAddress);
+ }
+
@Override
public Slice<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if (arguments.size() != 1)
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
index 503f414d8eb..bf5eaeb6c2e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
@@ -10,6 +10,7 @@ import com.yahoo.tensor.evaluation.TypeContext;
import java.util.List;
import java.util.Optional;
+import java.util.function.Function;
/**
* A representation of a tensor function which is able to be translated to a set of primitive
@@ -72,4 +73,9 @@ public abstract class TensorFunction<NAMETYPE extends Name> {
@Override
public abstract int hashCode();
+ public TensorFunction<NAMETYPE> withTransformedFunctions(
+ Function<ScalarFunction<NAMETYPE>, ScalarFunction<NAMETYPE>> transformer)
+ {
+ return this;
+ }
}