diff options
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java | 22 |
1 files changed, 20 insertions, 2 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java index c1f973300d6..68bebfa6183 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java @@ -11,6 +11,7 @@ import ai.onnxruntime.OrtException; import ai.onnxruntime.OrtSession; import ai.onnxruntime.TensorInfo; import ai.onnxruntime.ValueInfo; +import ai.vespa.rankingexpression.importer.onnx.OnnxImporter; import com.yahoo.tensor.DimensionSizes; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; @@ -25,6 +26,7 @@ import java.nio.LongBuffer; import java.nio.ShortBuffer; import java.util.HashMap; import java.util.Map; +import java.util.Set; import java.util.stream.Collectors; @@ -38,7 +40,8 @@ class TensorConverter { { Map<String, OnnxTensor> result = new HashMap<>(); for (String name : tensorMap.keySet()) { - Tensor vespaTensor = tensorMap.get(name); + Tensor vespaTensor = tensorMap.get(name); + name = toOnnxName(name, session.getInputInfo().keySet()); TensorInfo onnxTensorInfo = toTensorInfo(session.getInputInfo().get(name).getInfo()); OnnxTensor onnxTensor = toOnnxTensor(vespaTensor, onnxTensorInfo, env); result.put(name, onnxTensor); @@ -143,7 +146,22 @@ class TensorConverter { } static Map<String, TensorType> toVespaTypes(Map<String, NodeInfo> infoMap) { - return infoMap.entrySet().stream().collect(Collectors.toMap(e -> e.getKey(), e -> toVespaType(e.getValue().getInfo()))); + return infoMap.entrySet().stream().collect(Collectors.toMap(e -> asValidName(e.getKey()), + e -> toVespaType(e.getValue().getInfo()))); + } + + static String asValidName(String name) { + return OnnxImporter.asValidIdentifier(name); + } + + static String toOnnxName(String name, Set<String> onnxNames) { + if (onnxNames.contains(name)) + return name; + for (String onnxName : onnxNames) { + if (asValidName(onnxName).equals(name)) + return onnxName; + } + throw new IllegalArgumentException("ONNX model has no input with name " + name); } static TensorType toVespaType(ValueInfo valueInfo) { |