diff options
author | Arne Juul <arnej@yahooinc.com> | 2023-02-22 14:03:30 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2023-02-22 14:03:30 +0000 |
commit | 3aba185329c49b63c84365eda4e421cb7abd05cc (patch) | |
tree | 626b4509a081093ba454a4bf911325317faeacfa /config-model | |
parent | 890e0ac9e795ca1c95e459f98a54593ac151051c (diff) |
validate onnx model input/output mappings
Diffstat (limited to 'config-model')
-rw-r--r-- | config-model/src/main/java/com/yahoo/schema/OnnxModel.java | 33 |
1 files changed, 31 insertions, 2 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java index ae6f1fd96e4..6baaea6ea05 100644 --- a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java @@ -3,6 +3,7 @@ package com.yahoo.schema; import com.yahoo.tensor.TensorType; import com.yahoo.vespa.model.ml.OnnxModelInfo; +import com.yahoo.searchlib.rankingexpression.Reference; import java.util.Collections; import java.util.HashMap; @@ -44,11 +45,37 @@ public class OnnxModel extends DistributableResource { addInputNameMapping(onnxName, vespaName, true); } + private String validateInputSource(String source) { + var optRef = Reference.simple(source); + if (optRef.isPresent()) { + Reference ref = optRef.get(); + // input can be one of: + // attribute(foo), query(foo), constant(foo) + if (FeatureNames.isSimpleFeature(ref)) { + return ref.toString(); + } + // or a function (evaluated by backend) + if (ref.isSimple() && "rankingExpression".equals(ref.name())) { + var arg = ref.simpleArgument(); + if (arg.isPresent()) { + return ref.toString(); + } + } + } else { + // otherwise it must be an identifier + Reference ref = Reference.fromIdentifier(source); + return ref.toString(); + } + // invalid input source + throw new IllegalArgumentException("invalid input for ONNX model " + getName() + ": " + source); + } + public void addInputNameMapping(String onnxName, String vespaName, boolean overwrite) { Objects.requireNonNull(onnxName, "Onnx name cannot be null"); Objects.requireNonNull(vespaName, "Vespa name cannot be null"); + String source = validateInputSource(vespaName); if (overwrite || ! inputMap.containsKey(onnxName)) { - inputMap.put(onnxName, vespaName); + inputMap.put(onnxName, source); } } @@ -59,8 +86,10 @@ public class OnnxModel extends DistributableResource { public void addOutputNameMapping(String onnxName, String vespaName, boolean overwrite) { Objects.requireNonNull(onnxName, "Onnx name cannot be null"); Objects.requireNonNull(vespaName, "Vespa name cannot be null"); + // output name must be a valid identifier: + var ref = Reference.fromIdentifier(vespaName); if (overwrite || ! outputMap.containsKey(onnxName)) { - outputMap.put(onnxName, vespaName); + outputMap.put(onnxName, ref.toString()); } } |