diff options
author | Jon Bratseth <bratseth@gmail.com> | 2023-03-20 11:47:26 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-03-20 11:47:26 +0100 |
commit | 2306ee6febc80fb52bcb1f3d497e99807f3c1561 (patch) | |
tree | e23617d4ef5c90a335f6be92ca4b9f24c49eabb7 | |
parent | 3b399f5284d90e70417d381324f45396f92a6de1 (diff) | |
parent | 47e7d6a7509c5ba3a339afbaa0f17b16d9b382af (diff) |
Merge pull request #26480 from vespa-engine/arnej/handle-dynamic-tensor
Arnej/handle dynamic tensor
8 files changed, 60 insertions, 4 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java index 993a2442fdb..84d0997b12e 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java @@ -12,6 +12,7 @@ import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.tensor.functions.DynamicTensor; import com.yahoo.tensor.functions.Generate; import com.yahoo.tensor.functions.Slice; @@ -54,6 +55,13 @@ public class InputRecorder extends ExpressionTransformer<InputRecorderContext> { } return transformChildren(t, childContext); } + if (f instanceof DynamicTensor d) { + for (var tf : d.cellGeneratorFunctions()) { + if (tf instanceof TensorFunctionNode.ExpressionTensorFunction expr) { + transform(expr.wrappedExpression(), context); + } + } + } if (f instanceof Slice s) { for (var tf : s.selectorFunctions()) { if (tf instanceof TensorFunctionNode.ExpressionTensorFunction expr) { diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java index cf354a05a93..de12de9b747 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java @@ -289,6 +289,9 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform operators.add(Operator.plus); } } + if (operators.isEmpty() && factors.size() == 1) { + return factors.get(0); + } return new OperationNode(factors, operators); } diff --git a/config-model/src/test/derived/globalphase_onnx_inside/rank-profiles.cfg b/config-model/src/test/derived/globalphase_onnx_inside/rank-profiles.cfg index 58be9b400aa..9d21691a910 100644 --- a/config-model/src/test/derived/globalphase_onnx_inside/rank-profiles.cfg +++ b/config-model/src/test/derived/globalphase_onnx_inside/rank-profiles.cfg @@ -124,6 +124,8 @@ rankprofile[].fef.property[].name "vespa.rank.globalphase" rankprofile[].fef.property[].value "rankingExpression(globalphase)" rankprofile[].fef.property[].name "rankingExpression(globalphase).rankingScript" rankprofile[].fef.property[].value "reduce(constant(ww) * onnx(another).foobar, sum)" +rankprofile[].fef.property[].name "vespa.match.feature" +rankprofile[].fef.property[].value "attribute(extra)" rankprofile[].fef.property[].name "vespa.globalphase.rerankcount" rankprofile[].fef.property[].value "1001" rankprofile[].fef.property[].name "vespa.type.attribute.aa" diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java index 126e9f9f4e6..bc71e51655c 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java @@ -8,6 +8,7 @@ import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.functions.DynamicTensor; import com.yahoo.tensor.functions.Generate; import com.yahoo.tensor.functions.Slice; @@ -109,6 +110,13 @@ class BindingExtractor { } } } + else if (f instanceof DynamicTensor<?> d) { + for (var tf : d.cellGeneratorFunctions()) { + if (tf instanceof TensorFunctionNode.ExpressionTensorFunction expr) { + result.merge(extractBindTargets(expr.wrappedExpression())); + } + } + } return result; } else if (isOnnx(node)) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java index 1c66686a9fe..7ebbf0d582b 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java @@ -14,6 +14,7 @@ import java.util.Deque; import java.util.Iterator; import java.util.List; import java.util.Objects; +import java.util.logging.Logger; /** * A sequence of binary operations. @@ -22,12 +23,21 @@ import java.util.Objects; */ public final class OperationNode extends CompositeNode { + private static final Logger logger = Logger.getLogger(OperationNode.class.getName()); + private final List<ExpressionNode> children; private final List<Operator> operators; public OperationNode(List<ExpressionNode> children, List<Operator> operators) { this.children = List.copyOf(children); this.operators = List.copyOf(operators); + if (operators.isEmpty()) { + logger.warning("Strange: no operators for OperationNode"); + } + int needChildren = operators.size() + 1; + if (needChildren != children.size()) { + throw new IllegalArgumentException("Need " + needChildren + " children, but got " + children.size()); + } } public OperationNode(ExpressionNode leftExpression, Operator operator, ExpressionNode rightExpression) { @@ -70,12 +80,14 @@ public final class OperationNode extends CompositeNode { if ( parent == null) return false; if ( ! (parent instanceof OperationNode operationParent)) return false; - // The line below can only be correct in both only have one operator. + // The last line below can only be correct if both only have one operator. // Getting this correct is impossible without more work. // So for now we only handle the simple case correctly, and use a safe approach by adding // extra parenthesis just in case.... - return operationParent.operators.get(0).hasPrecedenceOver(this.operators.get(0)) - || ((operationParent.operators.size() > 1) || (operators.size() > 1)); + if ((operationParent.operators.size() != 1) || (operators.size() != 1)) { + return true; + } + return operationParent.operators.get(0).hasPrecedenceOver(this.operators.get(0)); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java index 1c1f7509ce8..4293ff29d0b 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java @@ -53,6 +53,9 @@ public class Simplifier extends ExpressionTransformer<TransformContext> { List<Operator> operators = new ArrayList<>(node.operators()); for (Operator operator : Operator.operatorsByPrecedence) transform(operator, children, operators); + if (operators.isEmpty() && children.size() == 1) { + return children.get(0); + } node = new OperationNode(children, operators); } @@ -69,7 +72,7 @@ public class Simplifier extends ExpressionTransformer<TransformContext> { int i = 0; while (i < children.size()-1) { boolean transformed = false; - if ( operators.get(i).equals(operatorToTransform)) { + if (operators.get(i).equals(operatorToTransform)) { ExpressionNode child1 = children.get(i); ExpressionNode child2 = children.get(i + 1); if (isConstant(child1) && isConstant(child2) && hasPrecedence(operators, i)) { diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 011e7e4a31d..e45b13a6eb0 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1730,6 +1730,7 @@ "methods" : [ "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)", "public java.util.List arguments()", + "public abstract java.util.List cellGeneratorFunctions()", "public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)", "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()", "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)", diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java index 558b01baa02..61d3acf6338 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java @@ -10,6 +10,7 @@ import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Objects; @@ -33,6 +34,8 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens @Override public List<TensorFunction<NAMETYPE>> arguments() { return List.of(); } + public abstract List<TensorFunction<NAMETYPE>> cellGeneratorFunctions(); + @Override public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if (arguments.size() != 0) @@ -71,6 +74,14 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens this.cells = ImmutableMap.copyOf(cells); } + public List<TensorFunction<NAMETYPE>> cellGeneratorFunctions() { + var result = new ArrayList<TensorFunction<NAMETYPE>>(); + for (var fun : cells.values()) { + fun.asTensorFunction().ifPresent(tf -> result.add(tf)); + } + return result; + } + @Override public Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor.Builder builder = Tensor.Builder.of(type()); @@ -115,6 +126,14 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens this.cells = List.copyOf(cells); } + public List<TensorFunction<NAMETYPE>> cellGeneratorFunctions() { + var result = new ArrayList<TensorFunction<NAMETYPE>>(); + for (var fun : cells) { + fun.asTensorFunction().ifPresent(tf -> result.add(tf)); + } + return result; + } + @Override public Tensor evaluate(EvaluationContext<NAMETYPE> context) { IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type()); |