diff options
author | Arne Juul <arnej@yahooinc.com> | 2023-03-02 11:05:41 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2023-03-02 11:09:17 +0000 |
commit | c9b02558e924375676a540c3fc8acae0ceafd886 (patch) | |
tree | 78ec59f02efe238432aa2660102cfbcaa53e30ae /model-evaluation | |
parent | fe2e8db7891c39559622ab4c3bbfc3fc5275fe1f (diff) |
use common utility and constant
Diffstat (limited to 'model-evaluation')
-rw-r--r-- | model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java | 4 | ||||
-rw-r--r-- | model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java | 9 |
2 files changed, 10 insertions, 3 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java index 46134074137..34e34a3341d 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java @@ -2,6 +2,7 @@ package ai.vespa.models.evaluation; import com.yahoo.collections.Pair; +import static com.yahoo.searchlib.rankingexpression.Reference.wrapInRankingExpression; import java.util.Objects; import java.util.Optional; @@ -51,7 +52,8 @@ class FunctionReference { } String serialForm() { - return "rankingExpression(" + name + (instance != null ? instance : "") + ")"; + String extra = (instance != null ? instance : ""); + return wrapInRankingExpression(name + extra); } @Override diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java index 81325740218..47c246c008e 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java @@ -16,6 +16,7 @@ import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.stream.CustomCollectors; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import static com.yahoo.searchlib.rankingexpression.Reference.RANKING_EXPRESSION_WRAPPER; import java.util.Arrays; import java.util.HashMap; @@ -233,7 +234,11 @@ public final class LazyArrayContext extends Context implements ContextIndex { List<OnnxModel> onnxModels, Map<String, OnnxModel> onnxModelsInUse) { if (isFunctionReference(node)) { - FunctionReference reference = FunctionReference.fromSerial(node.toString()).get(); + var opt = FunctionReference.fromSerial(node.toString()); + if (opt.isEmpty()) { + throw new IllegalArgumentException("Could not extract function " + node + " from serialized form '" + node.toString() +"'"); + } + FunctionReference reference = opt.get(); bindTargets.add(reference.serialForm()); ExpressionFunction function = functions.get(reference); @@ -313,7 +318,7 @@ public final class LazyArrayContext extends Context implements ContextIndex { private boolean isFunctionReference(ExpressionNode node) { if ( ! (node instanceof ReferenceNode reference)) return false; - return reference.getName().equals("rankingExpression") && reference.getArguments().size() == 1; + return reference.getName().equals(RANKING_EXPRESSION_WRAPPER) && reference.getArguments().size() == 1; } private boolean isOnnx(ExpressionNode node) { |