diff options
11 files changed, 31 insertions, 17 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java index 6baaea6ea05..272b668b5fb 100644 --- a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java @@ -55,7 +55,7 @@ public class OnnxModel extends DistributableResource { return ref.toString(); } // or a function (evaluated by backend) - if (ref.isSimple() && "rankingExpression".equals(ref.name())) { + if (ref.isSimpleRankingExpressionWrapper()) { var arg = ref.simpleArgument(); if (arg.isPresent()) { return ref.toString(); diff --git a/config-model/src/main/java/com/yahoo/schema/RankProfile.java b/config-model/src/main/java/com/yahoo/schema/RankProfile.java index 7cb0a088f5f..a00bbb682a8 100644 --- a/config-model/src/main/java/com/yahoo/schema/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/schema/RankProfile.java @@ -1169,7 +1169,7 @@ public class RankProfile implements Cloneable { // Source is either a simple reference (query/attribute/constant/rankingExpression)... Optional<Reference> reference = Reference.simple(source); if (reference.isPresent()) { - if (reference.get().name().equals("rankingExpression") && reference.get().simpleArgument().isPresent()) { + if (reference.get().isSimpleRankingExpressionWrapper()) { source = reference.get().simpleArgument().get(); // look up function below } else { return Optional.of(context.getType(reference.get())); diff --git a/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java index 31a38752bec..acb125197d2 100644 --- a/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java +++ b/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java @@ -20,6 +20,7 @@ import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.SerializationContext; import com.yahoo.vespa.config.search.RankProfilesConfig; +import static com.yahoo.searchlib.rankingexpression.Reference.wrapInRankingExpression; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; @@ -273,9 +274,9 @@ public class RawRankProfile implements RankProfilesConfig.Producer { String propertyName = RankingExpression.propertyName(referenceNode.getName()); String expressionString = function.getBody().getRoot().toString(context).toString(); context.addFunctionSerialization(propertyName, expressionString); - ReferenceNode backendReferenceNode = new ReferenceNode("rankingExpression(" + referenceNode.getName() + ")", - referenceNode.getArguments().expressions(), - referenceNode.getOutput()); + var backendReferenceNode = new ReferenceNode(wrapInRankingExpression(referenceNode.getName()), + referenceNode.getArguments().expressions(), + referenceNode.getOutput()); // tell backend to map back to the name the user expects: featureRenames.put(backendReferenceNode.toString(), referenceNode.toString()); functionFeatures.put(referenceNode.getName(), backendReferenceNode); @@ -499,7 +500,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { if (expression.getRoot() instanceof ReferenceNode) { properties.add(new Pair<>("vespa.rank." + phase, expression.getRoot().toString())); } else { - properties.add(new Pair<>("vespa.rank." + phase, "rankingExpression(" + name + ")")); + properties.add(new Pair<>("vespa.rank." + phase, wrapInRankingExpression(name))); properties.add(new Pair<>(RankingExpression.propertyName(name), expression.getRoot().toString())); } return properties; @@ -520,7 +521,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { for (Map.Entry<String, String> mapping : onnxModel.getInputMap().entrySet()) { String source = mapping.getValue(); if (functionNames.contains(source)) { - onnxModel.addInputNameMapping(mapping.getKey(), "rankingExpression(" + source + ")"); + onnxModel.addInputNameMapping(mapping.getKey(), wrapInRankingExpression(source)); } } } diff --git a/container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java b/container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java index ebdbbb693f1..cce6b42d323 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java @@ -3,6 +3,7 @@ package com.yahoo.search.ranking; import com.yahoo.search.result.FeatureData; import com.yahoo.search.result.Hit; +import static com.yahoo.searchlib.rankingexpression.Reference.RANKING_EXPRESSION_WRAPPER; import java.util.function.Supplier; import java.util.logging.Logger; @@ -42,7 +43,7 @@ class HitRescorer { } } - private static final String RE_PREFIX = "rankingExpression("; + private static final String RE_PREFIX = RANKING_EXPRESSION_WRAPPER + "("; private static final String RE_SUFFIX = ")"; private static final int RE_PRE_LEN = RE_PREFIX.length(); private static final int RE_SUF_LEN = RE_SUFFIX.length(); diff --git a/container-search/src/main/java/com/yahoo/search/result/FeatureData.java b/container-search/src/main/java/com/yahoo/search/result/FeatureData.java index 421f19475a6..7e9fa3f748a 100644 --- a/container-search/src/main/java/com/yahoo/search/result/FeatureData.java +++ b/container-search/src/main/java/com/yahoo/search/result/FeatureData.java @@ -11,6 +11,7 @@ import com.yahoo.io.GrowableByteBuffer; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.serialization.JsonFormat; import com.yahoo.tensor.serialization.TypedBinaryFormat; +import static com.yahoo.searchlib.rankingexpression.Reference.wrapInRankingExpression; import java.nio.charset.StandardCharsets; import java.util.Collections; @@ -144,7 +145,7 @@ public class FeatureData implements Inspectable, JsonProducer { if (featureValue.valid()) return featureValue; // Try to wrap by rankingExpression(name) - return value.field("rankingExpression(" + featureName + ")"); + return value.field(wrapInRankingExpression(featureName)); } /** Returns the names of the features available in this */ diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java index 46134074137..34e34a3341d 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java @@ -2,6 +2,7 @@ package ai.vespa.models.evaluation; import com.yahoo.collections.Pair; +import static com.yahoo.searchlib.rankingexpression.Reference.wrapInRankingExpression; import java.util.Objects; import java.util.Optional; @@ -51,7 +52,8 @@ class FunctionReference { } String serialForm() { - return "rankingExpression(" + name + (instance != null ? instance : "") + ")"; + String extra = (instance != null ? instance : ""); + return wrapInRankingExpression(name + extra); } @Override diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java index 81325740218..47c246c008e 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java @@ -16,6 +16,7 @@ import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.stream.CustomCollectors; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import static com.yahoo.searchlib.rankingexpression.Reference.RANKING_EXPRESSION_WRAPPER; import java.util.Arrays; import java.util.HashMap; @@ -233,7 +234,11 @@ public final class LazyArrayContext extends Context implements ContextIndex { List<OnnxModel> onnxModels, Map<String, OnnxModel> onnxModelsInUse) { if (isFunctionReference(node)) { - FunctionReference reference = FunctionReference.fromSerial(node.toString()).get(); + var opt = FunctionReference.fromSerial(node.toString()); + if (opt.isEmpty()) { + throw new IllegalArgumentException("Could not extract function " + node + " from serialized form '" + node.toString() +"'"); + } + FunctionReference reference = opt.get(); bindTargets.add(reference.serialForm()); ExpressionFunction function = functions.get(reference); @@ -313,7 +318,7 @@ public final class LazyArrayContext extends Context implements ContextIndex { private boolean isFunctionReference(ExpressionNode node) { if ( ! (node instanceof ReferenceNode reference)) return false; - return reference.getName().equals("rankingExpression") && reference.getArguments().size() == 1; + return reference.getName().equals(RANKING_EXPRESSION_WRAPPER) && reference.getArguments().size() == 1; } private boolean isOnnx(ExpressionNode node) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java index 171151bfdf4..c7d69d7a36a 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java @@ -10,6 +10,7 @@ import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.SerializationContext; import com.yahoo.tensor.TensorType; import com.yahoo.text.Utf8; +import static com.yahoo.searchlib.rankingexpression.Reference.wrapInRankingExpression; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; @@ -142,7 +143,7 @@ public class ExpressionFunction { if (shouldGenerateFeature(expr)) { String funcName = "autogenerated_ranking_feature@" + Long.toHexString(symbolCode(key + "=" + binding)); context.addFunctionSerialization(RankingExpression.propertyName(funcName), binding); - binding = "rankingExpression(" + funcName + ")"; + binding = wrapInRankingExpression(funcName); } argumentBindings.put(key, binding); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java index c9f818544e3..c6de04ed755 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java @@ -11,6 +11,7 @@ import com.yahoo.searchlib.rankingexpression.rule.SerializationContext; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.TypeContext; import com.yahoo.text.Text; +import static com.yahoo.searchlib.rankingexpression.Reference.RANKING_EXPRESSION_WRAPPER; import java.io.File; import java.io.FileNotFoundException; @@ -80,7 +81,7 @@ public class RankingExpression implements Serializable { private String name = ""; private ExpressionNode root; - private final static String RANKEXPRESSION = "rankingExpression("; + private final static String RANKEXPRESSION = RANKING_EXPRESSION_WRAPPER + "("; private final static String RANKINGSCRIPT = ").rankingScript"; private final static String EXPRESSION_NAME = ").expressionName"; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java index 85a12a49958..ec377c6f5d9 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java @@ -8,6 +8,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.TypeContext; +import static com.yahoo.searchlib.rankingexpression.Reference.wrapInRankingExpression; import java.util.ArrayDeque; import java.util.Deque; @@ -95,7 +96,7 @@ public final class ReferenceNode extends CompositeNode { context.addFunctionTypeSerialization(functionName, function.returnType().get()); } path.removeLast(); - return string.append("rankingExpression(").append(functionName).append(')'); + return string.append(wrapInRankingExpression(functionName)); } // Not resolved in this context: output as-is diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java index 7d0c0b98910..e2fffd824b9 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java @@ -6,6 +6,7 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.TypeContext; +import static com.yahoo.searchlib.rankingexpression.Reference.wrapInRankingExpression; import java.util.Collection; import java.util.Collections; @@ -97,13 +98,13 @@ public class SerializationContext extends FunctionReferenceContext { /** Adds the serialization of the argument type to a function */ public void addArgumentTypeSerialization(String functionName, String argumentName, TensorType type) { - serializedFunctions.put("rankingExpression(" + functionName + ")." + argumentName + ".type", type.toString()); + serializedFunctions.put(wrapInRankingExpression(functionName) + "." + argumentName + ".type", type.toString()); } /** Adds the serialization of the return type of a function */ public void addFunctionTypeSerialization(String functionName, TensorType type) { if (type.rank() == 0) return; // no explicit type implies scalar (aka rank 0 tensor) - serializedFunctions.put("rankingExpression(" + functionName + ").type", type.toString()); + serializedFunctions.put(wrapInRankingExpression(functionName) + ".type", type.toString()); } @Override |