summaryrefslogtreecommitdiffstats
path: root/model-evaluation
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2023-03-02 11:05:41 +0000
committerArne Juul <arnej@yahooinc.com>2023-03-02 11:09:17 +0000
commitc9b02558e924375676a540c3fc8acae0ceafd886 (patch)
tree78ec59f02efe238432aa2660102cfbcaa53e30ae /model-evaluation
parentfe2e8db7891c39559622ab4c3bbfc3fc5275fe1f (diff)
use common utility and constant
Diffstat (limited to 'model-evaluation')
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java4
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java9
2 files changed, 10 insertions, 3 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java
index 46134074137..34e34a3341d 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java
@@ -2,6 +2,7 @@
package ai.vespa.models.evaluation;
import com.yahoo.collections.Pair;
+import static com.yahoo.searchlib.rankingexpression.Reference.wrapInRankingExpression;
import java.util.Objects;
import java.util.Optional;
@@ -51,7 +52,8 @@ class FunctionReference {
}
String serialForm() {
- return "rankingExpression(" + name + (instance != null ? instance : "") + ")";
+ String extra = (instance != null ? instance : "");
+ return wrapInRankingExpression(name + extra);
}
@Override
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
index 81325740218..47c246c008e 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
@@ -16,6 +16,7 @@ import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.stream.CustomCollectors;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import static com.yahoo.searchlib.rankingexpression.Reference.RANKING_EXPRESSION_WRAPPER;
import java.util.Arrays;
import java.util.HashMap;
@@ -233,7 +234,11 @@ public final class LazyArrayContext extends Context implements ContextIndex {
List<OnnxModel> onnxModels,
Map<String, OnnxModel> onnxModelsInUse) {
if (isFunctionReference(node)) {
- FunctionReference reference = FunctionReference.fromSerial(node.toString()).get();
+ var opt = FunctionReference.fromSerial(node.toString());
+ if (opt.isEmpty()) {
+ throw new IllegalArgumentException("Could not extract function " + node + " from serialized form '" + node.toString() +"'");
+ }
+ FunctionReference reference = opt.get();
bindTargets.add(reference.serialForm());
ExpressionFunction function = functions.get(reference);
@@ -313,7 +318,7 @@ public final class LazyArrayContext extends Context implements ContextIndex {
private boolean isFunctionReference(ExpressionNode node) {
if ( ! (node instanceof ReferenceNode reference)) return false;
- return reference.getName().equals("rankingExpression") && reference.getArguments().size() == 1;
+ return reference.getName().equals(RANKING_EXPRESSION_WRAPPER) && reference.getArguments().size() == 1;
}
private boolean isOnnx(ExpressionNode node) {