summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2023-03-20 14:57:58 +0100
committerGitHub <noreply@github.com>2023-03-20 14:57:58 +0100
commit3075eced6674caef07fed92b9e311bdda67718a5 (patch)
tree63d5bb40bdd307db4ba14071aa45134c10eabed3
parente1502dcf57d9da6a7837a61fcca0cd7aa5e4f48e (diff)
parentbaa9c48be8732564e00730efe680df26a8f47f4c (diff)
Merge pull request #26501 from vespa-engine/arnej/add-transformer-support
Arnej/add transformer support
-rw-r--r--config-model/src/main/java/com/yahoo/schema/expressiontransforms/ConstantTensorTransformer.java33
-rw-r--r--config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java15
-rw-r--r--config-model/src/test/derived/scalar_constant/rank-profiles.cfg18
-rw-r--r--config-model/src/test/derived/scalar_constant/test.sd20
-rw-r--r--config-model/src/test/derived/vector_constant/ax_plus_b.onnx23
-rw-r--r--config-model/src/test/derived/vector_constant/rank-profiles.cfg26
-rw-r--r--config-model/src/test/derived/vector_constant/ranking-constants.cfg3
-rw-r--r--config-model/src/test/derived/vector_constant/test.sd33
-rw-r--r--config-model/src/test/java/com/yahoo/schema/derived/SmallConstantsTestCase.java27
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java19
-rw-r--r--searchlib/abi-spec.json1
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java20
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java4
-rw-r--r--vespajlib/abi-spec.json4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java24
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java17
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java6
17 files changed, 264 insertions, 29 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/ConstantTensorTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/ConstantTensorTransformer.java
index 4e320594918..a9eea3d2ead 100644
--- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/ConstantTensorTransformer.java
+++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/ConstantTensorTransformer.java
@@ -2,14 +2,18 @@
package com.yahoo.schema.expressiontransforms;
import com.yahoo.schema.FeatureNames;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
+import java.io.StringReader;
import java.util.ArrayList;
import java.util.List;
@@ -22,6 +26,9 @@ public class ConstantTensorTransformer extends ExpressionTransformer<RankProfile
@Override
public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
+ if (node instanceof TensorFunctionNode tfn) {
+ node = tfn.withTransformedExpressions(expr -> transform(expr, context));
+ }
if (node instanceof ReferenceNode) {
return transformFeature((ReferenceNode) node, context);
} else if (node instanceof CompositeNode) {
@@ -32,6 +39,32 @@ public class ConstantTensorTransformer extends ExpressionTransformer<RankProfile
}
private ExpressionNode transformFeature(ReferenceNode node, RankProfileTransformContext context) {
+ Reference ref = node.reference();
+ String name = ref.name();
+ var args = ref.arguments();
+ if (name.equals("onnx") && args.size() == 1) {
+ var arg = args.expressions().get(0);
+ var models = context.rankProfile().onnxModels();
+ var model = models.get(arg.toString());
+ if (model != null) {
+ for (var entry : model.getInputMap().entrySet()) {
+ String source = entry.getValue();
+ var reader = new StringReader(source);
+ try {
+ var asExpression = new RankingExpression(reader);
+ String transformed = transform(asExpression.getRoot(), context).toString();
+ if (! source.equals(transformed)) {
+ // not sure about this:
+ throw new IllegalStateException("unexpected rewrite: " + source + " => " + transformed + " for onnx input " + entry.getKey());
+ // consider instead: model.addInputNameMapping(entry.getKey(), transformed, true);
+ }
+ } catch (ParseException e) {
+ throw new IllegalArgumentException("illegal onnx input '" + source + "': " + e.getMessage());
+ }
+ }
+ return node;
+ }
+ }
if ( ! node.getArguments().isEmpty() && ! FeatureNames.isSimpleFeature(node.reference())) {
return transformArguments(node, context);
} else {
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
index 84d0997b12e..af072c5b59a 100644
--- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
+++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
@@ -55,20 +55,7 @@ public class InputRecorder extends ExpressionTransformer<InputRecorderContext> {
}
return transformChildren(t, childContext);
}
- if (f instanceof DynamicTensor d) {
- for (var tf : d.cellGeneratorFunctions()) {
- if (tf instanceof TensorFunctionNode.ExpressionTensorFunction expr) {
- transform(expr.wrappedExpression(), context);
- }
- }
- }
- if (f instanceof Slice s) {
- for (var tf : s.selectorFunctions()) {
- if (tf instanceof TensorFunctionNode.ExpressionTensorFunction expr) {
- transform(expr.wrappedExpression(), context);
- }
- }
- }
+ node = t.withTransformedExpressions(expr -> transform(expr, context));
}
if (node instanceof CompositeNode c)
return transformChildren(c, context);
diff --git a/config-model/src/test/derived/scalar_constant/rank-profiles.cfg b/config-model/src/test/derived/scalar_constant/rank-profiles.cfg
new file mode 100644
index 00000000000..5b494fa46fb
--- /dev/null
+++ b/config-model/src/test/derived/scalar_constant/rank-profiles.cfg
@@ -0,0 +1,18 @@
+rankprofile[].name "default"
+rankprofile[].fef.property[].name "rankingExpression(makevector).rankingScript"
+rankprofile[].fef.property[].value "tensor(x[3]):{{x:0}:0.25,{x:1}:0.5,{x:2}:0.75}"
+rankprofile[].fef.property[].name "rankingExpression(makevector).type"
+rankprofile[].fef.property[].value "tensor(x[3])"
+rankprofile[].fef.property[].name "vespa.rank.firstphase"
+rankprofile[].fef.property[].value "rankingExpression(firstphase)"
+rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript"
+rankprofile[].fef.property[].value "reduce(rankingExpression(makevector), sum)"
+rankprofile[].name "unranked"
+rankprofile[].fef.property[].name "vespa.rank.firstphase"
+rankprofile[].fef.property[].value "value(0)"
+rankprofile[].fef.property[].name "vespa.hitcollector.heapsize"
+rankprofile[].fef.property[].value "0"
+rankprofile[].fef.property[].name "vespa.hitcollector.arraysize"
+rankprofile[].fef.property[].value "0"
+rankprofile[].fef.property[].name "vespa.dump.ignoredefaultfeatures"
+rankprofile[].fef.property[].value "true"
diff --git a/config-model/src/test/derived/scalar_constant/test.sd b/config-model/src/test/derived/scalar_constant/test.sd
new file mode 100644
index 00000000000..5c05e0ba941
--- /dev/null
+++ b/config-model/src/test/derived/scalar_constant/test.sd
@@ -0,0 +1,20 @@
+# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+schema test {
+ document test {
+ field title type string {
+ indexing: index | summary
+ }
+ }
+ rank-profile default {
+ constants {
+ constant(foobar) double: 0.5
+ }
+ function makevector() {
+ expression: tensor(x[3]):[0.25, constant(foobar), 0.75]
+ }
+ first-phase {
+ expression: sum(makevector())
+ }
+ }
+}
diff --git a/config-model/src/test/derived/vector_constant/ax_plus_b.onnx b/config-model/src/test/derived/vector_constant/ax_plus_b.onnx
new file mode 100644
index 00000000000..17282d13dc3
--- /dev/null
+++ b/config-model/src/test/derived/vector_constant/ax_plus_b.onnx
@@ -0,0 +1,23 @@
+:©
+
+matrix_X
+vector_AXA"MatMul
+
+XA
+vector_Bvector_Y"AddlrZ
+matrix_X
+  
+
+Z
+vector_A
+
+ 
+Z
+vector_B
+
+ 
+b
+vector_Y
+
+ 
+B \ No newline at end of file
diff --git a/config-model/src/test/derived/vector_constant/rank-profiles.cfg b/config-model/src/test/derived/vector_constant/rank-profiles.cfg
new file mode 100644
index 00000000000..87edec6ca24
--- /dev/null
+++ b/config-model/src/test/derived/vector_constant/rank-profiles.cfg
@@ -0,0 +1,26 @@
+rankprofile[].name "default"
+rankprofile[].fef.property[].name "constant(bb).value"
+rankprofile[].fef.property[].value "tensor(d0[2]):[4.0, 5.0]"
+rankprofile[].fef.property[].name "constant(bb).type"
+rankprofile[].fef.property[].value "tensor(d0[2])"
+rankprofile[].fef.property[].name "constant(aa).value"
+rankprofile[].fef.property[].value "tensor(d1[3]):[1.0, 2.0, 3.0]"
+rankprofile[].fef.property[].name "constant(aa).type"
+rankprofile[].fef.property[].value "tensor(d1[3])"
+rankprofile[].fef.property[].name "rankingExpression(indirect_a).rankingScript"
+rankprofile[].fef.property[].value "tensor(d1[3]):{{d1:0}:2.0,{d1:1}:(constant(aa){d1:0}),{d1:2}:(constant(bb){d0:(2.0)})}"
+rankprofile[].fef.property[].name "rankingExpression(indirect_a).type"
+rankprofile[].fef.property[].value "tensor(d1[3])"
+rankprofile[].fef.property[].name "vespa.rank.firstphase"
+rankprofile[].fef.property[].value "rankingExpression(firstphase)"
+rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript"
+rankprofile[].fef.property[].value "reduce(onnx(inside).foobar, sum)"
+rankprofile[].name "unranked"
+rankprofile[].fef.property[].name "vespa.rank.firstphase"
+rankprofile[].fef.property[].value "value(0)"
+rankprofile[].fef.property[].name "vespa.hitcollector.heapsize"
+rankprofile[].fef.property[].value "0"
+rankprofile[].fef.property[].name "vespa.hitcollector.arraysize"
+rankprofile[].fef.property[].value "0"
+rankprofile[].fef.property[].name "vespa.dump.ignoredefaultfeatures"
+rankprofile[].fef.property[].value "true"
diff --git a/config-model/src/test/derived/vector_constant/ranking-constants.cfg b/config-model/src/test/derived/vector_constant/ranking-constants.cfg
new file mode 100644
index 00000000000..8637be8f175
--- /dev/null
+++ b/config-model/src/test/derived/vector_constant/ranking-constants.cfg
@@ -0,0 +1,3 @@
+constant[].name "xx"
+constant[].fileref "const_xx.json"
+constant[].type "tensor(d0[2],d1[3])"
diff --git a/config-model/src/test/derived/vector_constant/test.sd b/config-model/src/test/derived/vector_constant/test.sd
new file mode 100644
index 00000000000..508bd6505a7
--- /dev/null
+++ b/config-model/src/test/derived/vector_constant/test.sd
@@ -0,0 +1,33 @@
+schema test {
+ document test {
+ field extra type string {
+ }
+ }
+ constant xx {
+ file: const_xx.json
+ type: tensor(d0[2],d1[3])
+ }
+ rank-profile default {
+ constants {
+ constant(aa) tensor(d1[3]): [1,2,3]
+ bb tensor(d0[2]): [4,5]
+ dd double: 2
+ }
+ function indirect_a() {
+ expression: tensor(d1[3]): [constant(dd), constant(aa){d1:0}, constant(bb){d0:(constant(dd))}]
+ }
+ onnx-model inside {
+ file: ax_plus_b.onnx
+ input vector_A: indirect_a
+ input matrix_X: constant(xx)
+ input vector_B: constant(bb)
+ output vector_Y: foobar
+ }
+ first-phase {
+ expression: sum(onnx(inside).foobar)
+ }
+ # function unused() {
+ # expression: constant(aa)*constant(bb)
+ # }
+ }
+}
diff --git a/config-model/src/test/java/com/yahoo/schema/derived/SmallConstantsTestCase.java b/config-model/src/test/java/com/yahoo/schema/derived/SmallConstantsTestCase.java
new file mode 100644
index 00000000000..09352eb59fa
--- /dev/null
+++ b/config-model/src/test/java/com/yahoo/schema/derived/SmallConstantsTestCase.java
@@ -0,0 +1,27 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.schema.derived;
+
+import com.yahoo.schema.parser.ParseException;
+import org.junit.jupiter.api.Test;
+import java.io.IOException;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+/**
+ * Tests constants in rank-profile
+ *
+ * @author arnej
+ */
+public class SmallConstantsTestCase extends AbstractExportingTestCase {
+
+ @Test
+ void testScalarInRankProfile() throws IOException, ParseException {
+ assertCorrectDeriving("scalar_constant");
+ }
+
+ @Test
+ void testVectorInRankProfile() throws IOException, ParseException {
+ assertCorrectDeriving("vector_constant");
+ }
+
+}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java
index bc71e51655c..8cdd5387acd 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java
@@ -96,6 +96,11 @@ class BindingExtractor {
for (ExpressionNode child : tfn.children()) {
result.merge(extractBindTargets(child));
}
+ // ignore return value:
+ tfn.withTransformedExpressions(expr -> {
+ result.merge(extractBindTargets(expr));
+ return expr;
+ });
var f = tfn.function();
if (f instanceof Generate) {
var tt = f.type(null);
@@ -103,20 +108,6 @@ class BindingExtractor {
result.removeTarget(dim.name());
}
}
- else if (f instanceof Slice<?> s) {
- for (var selectorFunc : s.selectorFunctions()) {
- if (selectorFunc instanceof TensorFunctionNode.ExpressionTensorFunction expr) {
- result.merge(extractBindTargets(expr.wrappedExpression()));
- }
- }
- }
- else if (f instanceof DynamicTensor<?> d) {
- for (var tf : d.cellGeneratorFunctions()) {
- if (tf instanceof TensorFunctionNode.ExpressionTensorFunction expr) {
- result.merge(extractBindTargets(expr.wrappedExpression()));
- }
- }
- }
return result;
}
else if (isOnnx(node)) {
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json
index de3ec9648d1..f3fe86e261f 100644
--- a/searchlib/abi-spec.json
+++ b/searchlib/abi-spec.json
@@ -1696,6 +1696,7 @@
"public void <init>(com.yahoo.tensor.functions.TensorFunction)",
"public com.yahoo.tensor.functions.TensorFunction function()",
"public java.util.List children()",
+ "public com.yahoo.searchlib.rankingexpression.rule.ExpressionNode withTransformedExpressions(java.util.function.Function)",
"public com.yahoo.searchlib.rankingexpression.rule.CompositeNode setChildren(java.util.List)",
"public java.lang.StringBuilder toString(java.lang.StringBuilder, com.yahoo.searchlib.rankingexpression.rule.SerializationContext, java.util.Deque, com.yahoo.searchlib.rankingexpression.rule.CompositeNode)",
"public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index 7577c65527b..41ece967491 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -25,6 +25,7 @@ import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.function.Function;
import java.util.stream.Collectors;
/**
@@ -58,6 +59,25 @@ public class TensorFunctionNode extends CompositeNode {
return new TensorFunctionNode(f);
}
+ private static ScalarFunction<Reference> transform(ScalarFunction<Reference> input,
+ Function<ExpressionNode, ExpressionNode> transformer)
+ {
+ if (input instanceof ExpressionScalarFunction wrapper) {
+ ExpressionNode transformed = transformer.apply(wrapper.expression);
+ return new ExpressionScalarFunction(transformed);
+ }
+ return input;
+ }
+
+ public ExpressionNode withTransformedExpressions(Function<ExpressionNode, ExpressionNode> transformer) {
+ if (function instanceof ExpressionTensorFunction etf) {
+ ExpressionNode orig = etf.expression;
+ return transformer.apply(orig);
+ }
+ TensorFunction<Reference> transformed = function.withTransformedFunctions(fun -> transform(fun, transformer));
+ return new TensorFunctionNode(transformed);
+ }
+
@Override
public CompositeNode setChildren(List<ExpressionNode> children) {
List<TensorFunction<Reference>> wrappedChildren = children.stream()
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java
index 39afcfff541..225b260f403 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java
@@ -7,6 +7,7 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import java.util.ArrayList;
import java.util.List;
@@ -20,6 +21,9 @@ public class ConstantDereferencer extends ExpressionTransformer<TransformContext
@Override
public ExpressionNode transform(ExpressionNode node, TransformContext context) {
+ if (node instanceof TensorFunctionNode tfn) {
+ node = tfn.withTransformedExpressions(expr -> transform(expr, context));
+ }
if (node instanceof ReferenceNode)
return transformFeature((ReferenceNode) node, context);
else if (node instanceof CompositeNode)
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index e45b13a6eb0..88872fef8a1 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -2765,6 +2765,7 @@
"public void <init>(com.yahoo.tensor.functions.TensorFunction, java.util.List)",
"public java.util.List arguments()",
"public java.util.List selectorFunctions()",
+ "public com.yahoo.tensor.functions.TensorFunction withTransformedFunctions(java.util.function.Function)",
"public com.yahoo.tensor.functions.Slice withArguments(java.util.List)",
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
"public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)",
@@ -2810,7 +2811,8 @@
"public abstract java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
"public java.util.Optional asScalarFunction()",
"public java.lang.String toString()",
- "public abstract int hashCode()"
+ "public abstract int hashCode()",
+ "public com.yahoo.tensor.functions.TensorFunction withTransformedFunctions(java.util.function.Function)"
],
"fields" : [ ]
},
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
index 61d3acf6338..630eeb81d13 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
@@ -12,8 +12,10 @@ import com.yahoo.tensor.evaluation.TypeContext;
import java.util.ArrayList;
import java.util.List;
+import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;
+import java.util.function.Function;
/**
* A function which is a tensor whose values are computed by individual lambda functions on evaluation.
@@ -82,6 +84,17 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens
return result;
}
+ public TensorFunction<NAMETYPE> withTransformedFunctions(
+ Function<ScalarFunction<NAMETYPE>, ScalarFunction<NAMETYPE>> transformer)
+ {
+ Map<TensorAddress, ScalarFunction<NAMETYPE>> transformedCells = new LinkedHashMap<>();
+ for (var orig : cells.entrySet()) {
+ var transformed = transformer.apply(orig.getValue());
+ transformedCells.put(orig.getKey(), transformed);
+ }
+ return new MappedDynamicTensor<>(type(), transformedCells);
+ }
+
@Override
public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor.Builder builder = Tensor.Builder.of(type());
@@ -134,6 +147,17 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens
return result;
}
+ public TensorFunction<NAMETYPE> withTransformedFunctions(
+ Function<ScalarFunction<NAMETYPE>, ScalarFunction<NAMETYPE>> transformer)
+ {
+ List<ScalarFunction<NAMETYPE>> transformedCells = new ArrayList<>();
+ for (var orig : cells) {
+ var transformed = transformer.apply(orig);
+ transformedCells.add(transformed);
+ }
+ return new IndexedDynamicTensor<>(type(), transformedCells);
+ }
+
@Override
public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type());
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
index 87e24306031..066d75bcd9c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
@@ -16,6 +16,7 @@ import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
+import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
@@ -56,6 +57,22 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY
return result;
}
+ public TensorFunction<NAMETYPE> withTransformedFunctions(
+ Function<ScalarFunction<NAMETYPE>, ScalarFunction<NAMETYPE>> transformer)
+ {
+ List<DimensionValue<NAMETYPE>> transformedAddress = new ArrayList<>();
+ for (var orig : subspaceAddress) {
+ var idxFun = orig.index();
+ if (idxFun.isPresent()) {
+ var transformed = transformer.apply(idxFun.get());
+ transformedAddress.add(new DimensionValue<NAMETYPE>(orig.dimension(), transformed));
+ } else {
+ transformedAddress.add(orig);
+ }
+ }
+ return new Slice<>(argument, transformedAddress);
+ }
+
@Override
public Slice<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if (arguments.size() != 1)
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
index 503f414d8eb..bf5eaeb6c2e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
@@ -10,6 +10,7 @@ import com.yahoo.tensor.evaluation.TypeContext;
import java.util.List;
import java.util.Optional;
+import java.util.function.Function;
/**
* A representation of a tensor function which is able to be translated to a set of primitive
@@ -72,4 +73,9 @@ public abstract class TensorFunction<NAMETYPE extends Name> {
@Override
public abstract int hashCode();
+ public TensorFunction<NAMETYPE> withTransformedFunctions(
+ Function<ScalarFunction<NAMETYPE>, ScalarFunction<NAMETYPE>> transformer)
+ {
+ return this;
+ }
}