summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2023-03-20 11:47:26 +0100
committerGitHub <noreply@github.com>2023-03-20 11:47:26 +0100
commit2306ee6febc80fb52bcb1f3d497e99807f3c1561 (patch)
treee23617d4ef5c90a335f6be92ca4b9f24c49eabb7
parent3b399f5284d90e70417d381324f45396f92a6de1 (diff)
parent47e7d6a7509c5ba3a339afbaa0f17b16d9b382af (diff)
Merge pull request #26480 from vespa-engine/arnej/handle-dynamic-tensor
Arnej/handle dynamic tensor
-rw-r--r--config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java8
-rw-r--r--config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java3
-rw-r--r--config-model/src/test/derived/globalphase_onnx_inside/rank-profiles.cfg2
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java8
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java18
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java5
-rw-r--r--vespajlib/abi-spec.json1
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java19
8 files changed, 60 insertions, 4 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
index 993a2442fdb..84d0997b12e 100644
--- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
+++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
@@ -12,6 +12,7 @@ import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
+import com.yahoo.tensor.functions.DynamicTensor;
import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.Slice;
@@ -54,6 +55,13 @@ public class InputRecorder extends ExpressionTransformer<InputRecorderContext> {
}
return transformChildren(t, childContext);
}
+ if (f instanceof DynamicTensor d) {
+ for (var tf : d.cellGeneratorFunctions()) {
+ if (tf instanceof TensorFunctionNode.ExpressionTensorFunction expr) {
+ transform(expr.wrappedExpression(), context);
+ }
+ }
+ }
if (f instanceof Slice s) {
for (var tf : s.selectorFunctions()) {
if (tf instanceof TensorFunctionNode.ExpressionTensorFunction expr) {
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java
index cf354a05a93..de12de9b747 100644
--- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java
+++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java
@@ -289,6 +289,9 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
operators.add(Operator.plus);
}
}
+ if (operators.isEmpty() && factors.size() == 1) {
+ return factors.get(0);
+ }
return new OperationNode(factors, operators);
}
diff --git a/config-model/src/test/derived/globalphase_onnx_inside/rank-profiles.cfg b/config-model/src/test/derived/globalphase_onnx_inside/rank-profiles.cfg
index 58be9b400aa..9d21691a910 100644
--- a/config-model/src/test/derived/globalphase_onnx_inside/rank-profiles.cfg
+++ b/config-model/src/test/derived/globalphase_onnx_inside/rank-profiles.cfg
@@ -124,6 +124,8 @@ rankprofile[].fef.property[].name "vespa.rank.globalphase"
rankprofile[].fef.property[].value "rankingExpression(globalphase)"
rankprofile[].fef.property[].name "rankingExpression(globalphase).rankingScript"
rankprofile[].fef.property[].value "reduce(constant(ww) * onnx(another).foobar, sum)"
+rankprofile[].fef.property[].name "vespa.match.feature"
+rankprofile[].fef.property[].value "attribute(extra)"
rankprofile[].fef.property[].name "vespa.globalphase.rerankcount"
rankprofile[].fef.property[].value "1001"
rankprofile[].fef.property[].name "vespa.type.attribute.aa"
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java
index 126e9f9f4e6..bc71e51655c 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java
@@ -8,6 +8,7 @@ import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.functions.DynamicTensor;
import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.Slice;
@@ -109,6 +110,13 @@ class BindingExtractor {
}
}
}
+ else if (f instanceof DynamicTensor<?> d) {
+ for (var tf : d.cellGeneratorFunctions()) {
+ if (tf instanceof TensorFunctionNode.ExpressionTensorFunction expr) {
+ result.merge(extractBindTargets(expr.wrappedExpression()));
+ }
+ }
+ }
return result;
}
else if (isOnnx(node)) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java
index 1c66686a9fe..7ebbf0d582b 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java
@@ -14,6 +14,7 @@ import java.util.Deque;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
+import java.util.logging.Logger;
/**
* A sequence of binary operations.
@@ -22,12 +23,21 @@ import java.util.Objects;
*/
public final class OperationNode extends CompositeNode {
+ private static final Logger logger = Logger.getLogger(OperationNode.class.getName());
+
private final List<ExpressionNode> children;
private final List<Operator> operators;
public OperationNode(List<ExpressionNode> children, List<Operator> operators) {
this.children = List.copyOf(children);
this.operators = List.copyOf(operators);
+ if (operators.isEmpty()) {
+ logger.warning("Strange: no operators for OperationNode");
+ }
+ int needChildren = operators.size() + 1;
+ if (needChildren != children.size()) {
+ throw new IllegalArgumentException("Need " + needChildren + " children, but got " + children.size());
+ }
}
public OperationNode(ExpressionNode leftExpression, Operator operator, ExpressionNode rightExpression) {
@@ -70,12 +80,14 @@ public final class OperationNode extends CompositeNode {
if ( parent == null) return false;
if ( ! (parent instanceof OperationNode operationParent)) return false;
- // The line below can only be correct in both only have one operator.
+ // The last line below can only be correct if both only have one operator.
// Getting this correct is impossible without more work.
// So for now we only handle the simple case correctly, and use a safe approach by adding
// extra parenthesis just in case....
- return operationParent.operators.get(0).hasPrecedenceOver(this.operators.get(0))
- || ((operationParent.operators.size() > 1) || (operators.size() > 1));
+ if ((operationParent.operators.size() != 1) || (operators.size() != 1)) {
+ return true;
+ }
+ return operationParent.operators.get(0).hasPrecedenceOver(this.operators.get(0));
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
index 1c1f7509ce8..4293ff29d0b 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
@@ -53,6 +53,9 @@ public class Simplifier extends ExpressionTransformer<TransformContext> {
List<Operator> operators = new ArrayList<>(node.operators());
for (Operator operator : Operator.operatorsByPrecedence)
transform(operator, children, operators);
+ if (operators.isEmpty() && children.size() == 1) {
+ return children.get(0);
+ }
node = new OperationNode(children, operators);
}
@@ -69,7 +72,7 @@ public class Simplifier extends ExpressionTransformer<TransformContext> {
int i = 0;
while (i < children.size()-1) {
boolean transformed = false;
- if ( operators.get(i).equals(operatorToTransform)) {
+ if (operators.get(i).equals(operatorToTransform)) {
ExpressionNode child1 = children.get(i);
ExpressionNode child2 = children.get(i + 1);
if (isConstant(child1) && isConstant(child2) && hasPrecedence(operators, i)) {
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 011e7e4a31d..e45b13a6eb0 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -1730,6 +1730,7 @@
"methods" : [
"public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
"public java.util.List arguments()",
+ "public abstract java.util.List cellGeneratorFunctions()",
"public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)",
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
"public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
index 558b01baa02..61d3acf6338 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
@@ -10,6 +10,7 @@ import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
+import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -33,6 +34,8 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens
@Override
public List<TensorFunction<NAMETYPE>> arguments() { return List.of(); }
+ public abstract List<TensorFunction<NAMETYPE>> cellGeneratorFunctions();
+
@Override
public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if (arguments.size() != 0)
@@ -71,6 +74,14 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens
this.cells = ImmutableMap.copyOf(cells);
}
+ public List<TensorFunction<NAMETYPE>> cellGeneratorFunctions() {
+ var result = new ArrayList<TensorFunction<NAMETYPE>>();
+ for (var fun : cells.values()) {
+ fun.asTensorFunction().ifPresent(tf -> result.add(tf));
+ }
+ return result;
+ }
+
@Override
public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor.Builder builder = Tensor.Builder.of(type());
@@ -115,6 +126,14 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens
this.cells = List.copyOf(cells);
}
+ public List<TensorFunction<NAMETYPE>> cellGeneratorFunctions() {
+ var result = new ArrayList<TensorFunction<NAMETYPE>>();
+ for (var fun : cells) {
+ fun.asTensorFunction().ifPresent(tf -> result.add(tf));
+ }
+ return result;
+ }
+
@Override
public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type());