diff options
author | Jon Bratseth <bratseth@gmail.com> | 2023-03-20 14:57:58 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-03-20 14:57:58 +0100 |
commit | 3075eced6674caef07fed92b9e311bdda67718a5 (patch) | |
tree | 63d5bb40bdd307db4ba14071aa45134c10eabed3 | |
parent | e1502dcf57d9da6a7837a61fcca0cd7aa5e4f48e (diff) | |
parent | baa9c48be8732564e00730efe680df26a8f47f4c (diff) |
Merge pull request #26501 from vespa-engine/arnej/add-transformer-support
Arnej/add transformer support
17 files changed, 264 insertions, 29 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/ConstantTensorTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/ConstantTensorTransformer.java index 4e320594918..a9eea3d2ead 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/ConstantTensorTransformer.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/ConstantTensorTransformer.java @@ -2,14 +2,18 @@ package com.yahoo.schema.expressiontransforms; import com.yahoo.schema.FeatureNames; +import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import java.io.StringReader; import java.util.ArrayList; import java.util.List; @@ -22,6 +26,9 @@ public class ConstantTensorTransformer extends ExpressionTransformer<RankProfile @Override public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) { + if (node instanceof TensorFunctionNode tfn) { + node = tfn.withTransformedExpressions(expr -> transform(expr, context)); + } if (node instanceof ReferenceNode) { return transformFeature((ReferenceNode) node, context); } else if (node instanceof CompositeNode) { @@ -32,6 +39,32 @@ public class ConstantTensorTransformer extends ExpressionTransformer<RankProfile } private ExpressionNode transformFeature(ReferenceNode node, RankProfileTransformContext context) { + Reference ref = node.reference(); + String name = ref.name(); + var args = ref.arguments(); + if (name.equals("onnx") && args.size() == 1) { + var arg = args.expressions().get(0); + var models = context.rankProfile().onnxModels(); + var model = models.get(arg.toString()); + if (model != null) { + for (var entry : model.getInputMap().entrySet()) { + String source = entry.getValue(); + var reader = new StringReader(source); + try { + var asExpression = new RankingExpression(reader); + String transformed = transform(asExpression.getRoot(), context).toString(); + if (! source.equals(transformed)) { + // not sure about this: + throw new IllegalStateException("unexpected rewrite: " + source + " => " + transformed + " for onnx input " + entry.getKey()); + // consider instead: model.addInputNameMapping(entry.getKey(), transformed, true); + } + } catch (ParseException e) { + throw new IllegalArgumentException("illegal onnx input '" + source + "': " + e.getMessage()); + } + } + return node; + } + } if ( ! node.getArguments().isEmpty() && ! FeatureNames.isSimpleFeature(node.reference())) { return transformArguments(node, context); } else { diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java index 84d0997b12e..af072c5b59a 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java @@ -55,20 +55,7 @@ public class InputRecorder extends ExpressionTransformer<InputRecorderContext> { } return transformChildren(t, childContext); } - if (f instanceof DynamicTensor d) { - for (var tf : d.cellGeneratorFunctions()) { - if (tf instanceof TensorFunctionNode.ExpressionTensorFunction expr) { - transform(expr.wrappedExpression(), context); - } - } - } - if (f instanceof Slice s) { - for (var tf : s.selectorFunctions()) { - if (tf instanceof TensorFunctionNode.ExpressionTensorFunction expr) { - transform(expr.wrappedExpression(), context); - } - } - } + node = t.withTransformedExpressions(expr -> transform(expr, context)); } if (node instanceof CompositeNode c) return transformChildren(c, context); diff --git a/config-model/src/test/derived/scalar_constant/rank-profiles.cfg b/config-model/src/test/derived/scalar_constant/rank-profiles.cfg new file mode 100644 index 00000000000..5b494fa46fb --- /dev/null +++ b/config-model/src/test/derived/scalar_constant/rank-profiles.cfg @@ -0,0 +1,18 @@ +rankprofile[].name "default" +rankprofile[].fef.property[].name "rankingExpression(makevector).rankingScript" +rankprofile[].fef.property[].value "tensor(x[3]):{{x:0}:0.25,{x:1}:0.5,{x:2}:0.75}" +rankprofile[].fef.property[].name "rankingExpression(makevector).type" +rankprofile[].fef.property[].value "tensor(x[3])" +rankprofile[].fef.property[].name "vespa.rank.firstphase" +rankprofile[].fef.property[].value "rankingExpression(firstphase)" +rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript" +rankprofile[].fef.property[].value "reduce(rankingExpression(makevector), sum)" +rankprofile[].name "unranked" +rankprofile[].fef.property[].name "vespa.rank.firstphase" +rankprofile[].fef.property[].value "value(0)" +rankprofile[].fef.property[].name "vespa.hitcollector.heapsize" +rankprofile[].fef.property[].value "0" +rankprofile[].fef.property[].name "vespa.hitcollector.arraysize" +rankprofile[].fef.property[].value "0" +rankprofile[].fef.property[].name "vespa.dump.ignoredefaultfeatures" +rankprofile[].fef.property[].value "true" diff --git a/config-model/src/test/derived/scalar_constant/test.sd b/config-model/src/test/derived/scalar_constant/test.sd new file mode 100644 index 00000000000..5c05e0ba941 --- /dev/null +++ b/config-model/src/test/derived/scalar_constant/test.sd @@ -0,0 +1,20 @@ +# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +schema test { + document test { + field title type string { + indexing: index | summary + } + } + rank-profile default { + constants { + constant(foobar) double: 0.5 + } + function makevector() { + expression: tensor(x[3]):[0.25, constant(foobar), 0.75] + } + first-phase { + expression: sum(makevector()) + } + } +} diff --git a/config-model/src/test/derived/vector_constant/ax_plus_b.onnx b/config-model/src/test/derived/vector_constant/ax_plus_b.onnx new file mode 100644 index 00000000000..17282d13dc3 --- /dev/null +++ b/config-model/src/test/derived/vector_constant/ax_plus_b.onnx @@ -0,0 +1,23 @@ +:© + +matrix_X +vector_AXA"MatMul + +XA +vector_Bvector_Y"AddlrZ +matrix_X + + +Z +vector_A + + +Z +vector_B + + +b +vector_Y + + +B
\ No newline at end of file diff --git a/config-model/src/test/derived/vector_constant/rank-profiles.cfg b/config-model/src/test/derived/vector_constant/rank-profiles.cfg new file mode 100644 index 00000000000..87edec6ca24 --- /dev/null +++ b/config-model/src/test/derived/vector_constant/rank-profiles.cfg @@ -0,0 +1,26 @@ +rankprofile[].name "default" +rankprofile[].fef.property[].name "constant(bb).value" +rankprofile[].fef.property[].value "tensor(d0[2]):[4.0, 5.0]" +rankprofile[].fef.property[].name "constant(bb).type" +rankprofile[].fef.property[].value "tensor(d0[2])" +rankprofile[].fef.property[].name "constant(aa).value" +rankprofile[].fef.property[].value "tensor(d1[3]):[1.0, 2.0, 3.0]" +rankprofile[].fef.property[].name "constant(aa).type" +rankprofile[].fef.property[].value "tensor(d1[3])" +rankprofile[].fef.property[].name "rankingExpression(indirect_a).rankingScript" +rankprofile[].fef.property[].value "tensor(d1[3]):{{d1:0}:2.0,{d1:1}:(constant(aa){d1:0}),{d1:2}:(constant(bb){d0:(2.0)})}" +rankprofile[].fef.property[].name "rankingExpression(indirect_a).type" +rankprofile[].fef.property[].value "tensor(d1[3])" +rankprofile[].fef.property[].name "vespa.rank.firstphase" +rankprofile[].fef.property[].value "rankingExpression(firstphase)" +rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript" +rankprofile[].fef.property[].value "reduce(onnx(inside).foobar, sum)" +rankprofile[].name "unranked" +rankprofile[].fef.property[].name "vespa.rank.firstphase" +rankprofile[].fef.property[].value "value(0)" +rankprofile[].fef.property[].name "vespa.hitcollector.heapsize" +rankprofile[].fef.property[].value "0" +rankprofile[].fef.property[].name "vespa.hitcollector.arraysize" +rankprofile[].fef.property[].value "0" +rankprofile[].fef.property[].name "vespa.dump.ignoredefaultfeatures" +rankprofile[].fef.property[].value "true" diff --git a/config-model/src/test/derived/vector_constant/ranking-constants.cfg b/config-model/src/test/derived/vector_constant/ranking-constants.cfg new file mode 100644 index 00000000000..8637be8f175 --- /dev/null +++ b/config-model/src/test/derived/vector_constant/ranking-constants.cfg @@ -0,0 +1,3 @@ +constant[].name "xx" +constant[].fileref "const_xx.json" +constant[].type "tensor(d0[2],d1[3])" diff --git a/config-model/src/test/derived/vector_constant/test.sd b/config-model/src/test/derived/vector_constant/test.sd new file mode 100644 index 00000000000..508bd6505a7 --- /dev/null +++ b/config-model/src/test/derived/vector_constant/test.sd @@ -0,0 +1,33 @@ +schema test { + document test { + field extra type string { + } + } + constant xx { + file: const_xx.json + type: tensor(d0[2],d1[3]) + } + rank-profile default { + constants { + constant(aa) tensor(d1[3]): [1,2,3] + bb tensor(d0[2]): [4,5] + dd double: 2 + } + function indirect_a() { + expression: tensor(d1[3]): [constant(dd), constant(aa){d1:0}, constant(bb){d0:(constant(dd))}] + } + onnx-model inside { + file: ax_plus_b.onnx + input vector_A: indirect_a + input matrix_X: constant(xx) + input vector_B: constant(bb) + output vector_Y: foobar + } + first-phase { + expression: sum(onnx(inside).foobar) + } + # function unused() { + # expression: constant(aa)*constant(bb) + # } + } +} diff --git a/config-model/src/test/java/com/yahoo/schema/derived/SmallConstantsTestCase.java b/config-model/src/test/java/com/yahoo/schema/derived/SmallConstantsTestCase.java new file mode 100644 index 00000000000..09352eb59fa --- /dev/null +++ b/config-model/src/test/java/com/yahoo/schema/derived/SmallConstantsTestCase.java @@ -0,0 +1,27 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.schema.derived; + +import com.yahoo.schema.parser.ParseException; +import org.junit.jupiter.api.Test; +import java.io.IOException; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +/** + * Tests constants in rank-profile + * + * @author arnej + */ +public class SmallConstantsTestCase extends AbstractExportingTestCase { + + @Test + void testScalarInRankProfile() throws IOException, ParseException { + assertCorrectDeriving("scalar_constant"); + } + + @Test + void testVectorInRankProfile() throws IOException, ParseException { + assertCorrectDeriving("vector_constant"); + } + +} diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java index bc71e51655c..8cdd5387acd 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java @@ -96,6 +96,11 @@ class BindingExtractor { for (ExpressionNode child : tfn.children()) { result.merge(extractBindTargets(child)); } + // ignore return value: + tfn.withTransformedExpressions(expr -> { + result.merge(extractBindTargets(expr)); + return expr; + }); var f = tfn.function(); if (f instanceof Generate) { var tt = f.type(null); @@ -103,20 +108,6 @@ class BindingExtractor { result.removeTarget(dim.name()); } } - else if (f instanceof Slice<?> s) { - for (var selectorFunc : s.selectorFunctions()) { - if (selectorFunc instanceof TensorFunctionNode.ExpressionTensorFunction expr) { - result.merge(extractBindTargets(expr.wrappedExpression())); - } - } - } - else if (f instanceof DynamicTensor<?> d) { - for (var tf : d.cellGeneratorFunctions()) { - if (tf instanceof TensorFunctionNode.ExpressionTensorFunction expr) { - result.merge(extractBindTargets(expr.wrappedExpression())); - } - } - } return result; } else if (isOnnx(node)) { diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json index de3ec9648d1..f3fe86e261f 100644 --- a/searchlib/abi-spec.json +++ b/searchlib/abi-spec.json @@ -1696,6 +1696,7 @@ "public void <init>(com.yahoo.tensor.functions.TensorFunction)", "public com.yahoo.tensor.functions.TensorFunction function()", "public java.util.List children()", + "public com.yahoo.searchlib.rankingexpression.rule.ExpressionNode withTransformedExpressions(java.util.function.Function)", "public com.yahoo.searchlib.rankingexpression.rule.CompositeNode setChildren(java.util.List)", "public java.lang.StringBuilder toString(java.lang.StringBuilder, com.yahoo.searchlib.rankingexpression.rule.SerializationContext, java.util.Deque, com.yahoo.searchlib.rankingexpression.rule.CompositeNode)", "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)", diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index 7577c65527b..41ece967491 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -25,6 +25,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.function.Function; import java.util.stream.Collectors; /** @@ -58,6 +59,25 @@ public class TensorFunctionNode extends CompositeNode { return new TensorFunctionNode(f); } + private static ScalarFunction<Reference> transform(ScalarFunction<Reference> input, + Function<ExpressionNode, ExpressionNode> transformer) + { + if (input instanceof ExpressionScalarFunction wrapper) { + ExpressionNode transformed = transformer.apply(wrapper.expression); + return new ExpressionScalarFunction(transformed); + } + return input; + } + + public ExpressionNode withTransformedExpressions(Function<ExpressionNode, ExpressionNode> transformer) { + if (function instanceof ExpressionTensorFunction etf) { + ExpressionNode orig = etf.expression; + return transformer.apply(orig); + } + TensorFunction<Reference> transformed = function.withTransformedFunctions(fun -> transform(fun, transformer)); + return new TensorFunctionNode(transformed); + } + @Override public CompositeNode setChildren(List<ExpressionNode> children) { List<TensorFunction<Reference>> wrappedChildren = children.stream() diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java index 39afcfff541..225b260f403 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java @@ -7,6 +7,7 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import java.util.ArrayList; import java.util.List; @@ -20,6 +21,9 @@ public class ConstantDereferencer extends ExpressionTransformer<TransformContext @Override public ExpressionNode transform(ExpressionNode node, TransformContext context) { + if (node instanceof TensorFunctionNode tfn) { + node = tfn.withTransformedExpressions(expr -> transform(expr, context)); + } if (node instanceof ReferenceNode) return transformFeature((ReferenceNode) node, context); else if (node instanceof CompositeNode) diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index e45b13a6eb0..88872fef8a1 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -2765,6 +2765,7 @@ "public void <init>(com.yahoo.tensor.functions.TensorFunction, java.util.List)", "public java.util.List arguments()", "public java.util.List selectorFunctions()", + "public com.yahoo.tensor.functions.TensorFunction withTransformedFunctions(java.util.function.Function)", "public com.yahoo.tensor.functions.Slice withArguments(java.util.List)", "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()", "public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)", @@ -2810,7 +2811,8 @@ "public abstract java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)", "public java.util.Optional asScalarFunction()", "public java.lang.String toString()", - "public abstract int hashCode()" + "public abstract int hashCode()", + "public com.yahoo.tensor.functions.TensorFunction withTransformedFunctions(java.util.function.Function)" ], "fields" : [ ] }, diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java index 61d3acf6338..630eeb81d13 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java @@ -12,8 +12,10 @@ import com.yahoo.tensor.evaluation.TypeContext; import java.util.ArrayList; import java.util.List; +import java.util.LinkedHashMap; import java.util.Map; import java.util.Objects; +import java.util.function.Function; /** * A function which is a tensor whose values are computed by individual lambda functions on evaluation. @@ -82,6 +84,17 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens return result; } + public TensorFunction<NAMETYPE> withTransformedFunctions( + Function<ScalarFunction<NAMETYPE>, ScalarFunction<NAMETYPE>> transformer) + { + Map<TensorAddress, ScalarFunction<NAMETYPE>> transformedCells = new LinkedHashMap<>(); + for (var orig : cells.entrySet()) { + var transformed = transformer.apply(orig.getValue()); + transformedCells.put(orig.getKey(), transformed); + } + return new MappedDynamicTensor<>(type(), transformedCells); + } + @Override public Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor.Builder builder = Tensor.Builder.of(type()); @@ -134,6 +147,17 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens return result; } + public TensorFunction<NAMETYPE> withTransformedFunctions( + Function<ScalarFunction<NAMETYPE>, ScalarFunction<NAMETYPE>> transformer) + { + List<ScalarFunction<NAMETYPE>> transformedCells = new ArrayList<>(); + for (var orig : cells) { + var transformed = transformer.apply(orig); + transformedCells.add(transformed); + } + return new IndexedDynamicTensor<>(type(), transformedCells); + } + @Override public Tensor evaluate(EvaluationContext<NAMETYPE> context) { IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java index 87e24306031..066d75bcd9c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java @@ -16,6 +16,7 @@ import java.util.Iterator; import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.function.Function; import java.util.function.Predicate; import java.util.stream.Collectors; @@ -56,6 +57,22 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY return result; } + public TensorFunction<NAMETYPE> withTransformedFunctions( + Function<ScalarFunction<NAMETYPE>, ScalarFunction<NAMETYPE>> transformer) + { + List<DimensionValue<NAMETYPE>> transformedAddress = new ArrayList<>(); + for (var orig : subspaceAddress) { + var idxFun = orig.index(); + if (idxFun.isPresent()) { + var transformed = transformer.apply(idxFun.get()); + transformedAddress.add(new DimensionValue<NAMETYPE>(orig.dimension(), transformed)); + } else { + transformedAddress.add(orig); + } + } + return new Slice<>(argument, transformedAddress); + } + @Override public Slice<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if (arguments.size() != 1) diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java index 503f414d8eb..bf5eaeb6c2e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java @@ -10,6 +10,7 @@ import com.yahoo.tensor.evaluation.TypeContext; import java.util.List; import java.util.Optional; +import java.util.function.Function; /** * A representation of a tensor function which is able to be translated to a set of primitive @@ -72,4 +73,9 @@ public abstract class TensorFunction<NAMETYPE extends Name> { @Override public abstract int hashCode(); + public TensorFunction<NAMETYPE> withTransformedFunctions( + Function<ScalarFunction<NAMETYPE>, ScalarFunction<NAMETYPE>> transformer) + { + return this; + } } |