summaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2023-02-22 14:03:30 +0000
committerArne Juul <arnej@yahooinc.com>2023-02-22 14:03:30 +0000
commit3aba185329c49b63c84365eda4e421cb7abd05cc (patch)
tree626b4509a081093ba454a4bf911325317faeacfa /config-model
parent890e0ac9e795ca1c95e459f98a54593ac151051c (diff)
validate onnx model input/output mappings
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/schema/OnnxModel.java33
1 files changed, 31 insertions, 2 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
index ae6f1fd96e4..6baaea6ea05 100644
--- a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
+++ b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
@@ -3,6 +3,7 @@ package com.yahoo.schema;
import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.model.ml.OnnxModelInfo;
+import com.yahoo.searchlib.rankingexpression.Reference;
import java.util.Collections;
import java.util.HashMap;
@@ -44,11 +45,37 @@ public class OnnxModel extends DistributableResource {
addInputNameMapping(onnxName, vespaName, true);
}
+ private String validateInputSource(String source) {
+ var optRef = Reference.simple(source);
+ if (optRef.isPresent()) {
+ Reference ref = optRef.get();
+ // input can be one of:
+ // attribute(foo), query(foo), constant(foo)
+ if (FeatureNames.isSimpleFeature(ref)) {
+ return ref.toString();
+ }
+ // or a function (evaluated by backend)
+ if (ref.isSimple() && "rankingExpression".equals(ref.name())) {
+ var arg = ref.simpleArgument();
+ if (arg.isPresent()) {
+ return ref.toString();
+ }
+ }
+ } else {
+ // otherwise it must be an identifier
+ Reference ref = Reference.fromIdentifier(source);
+ return ref.toString();
+ }
+ // invalid input source
+ throw new IllegalArgumentException("invalid input for ONNX model " + getName() + ": " + source);
+ }
+
public void addInputNameMapping(String onnxName, String vespaName, boolean overwrite) {
Objects.requireNonNull(onnxName, "Onnx name cannot be null");
Objects.requireNonNull(vespaName, "Vespa name cannot be null");
+ String source = validateInputSource(vespaName);
if (overwrite || ! inputMap.containsKey(onnxName)) {
- inputMap.put(onnxName, vespaName);
+ inputMap.put(onnxName, source);
}
}
@@ -59,8 +86,10 @@ public class OnnxModel extends DistributableResource {
public void addOutputNameMapping(String onnxName, String vespaName, boolean overwrite) {
Objects.requireNonNull(onnxName, "Onnx name cannot be null");
Objects.requireNonNull(vespaName, "Vespa name cannot be null");
+ // output name must be a valid identifier:
+ var ref = Reference.fromIdentifier(vespaName);
if (overwrite || ! outputMap.containsKey(onnxName)) {
- outputMap.put(onnxName, vespaName);
+ outputMap.put(onnxName, ref.toString());
}
}