summaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java22
1 files changed, 20 insertions, 2 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java
index c1f973300d6..68bebfa6183 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java
@@ -11,6 +11,7 @@ import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.TensorInfo;
import ai.onnxruntime.ValueInfo;
+import ai.vespa.rankingexpression.importer.onnx.OnnxImporter;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
@@ -25,6 +26,7 @@ import java.nio.LongBuffer;
import java.nio.ShortBuffer;
import java.util.HashMap;
import java.util.Map;
+import java.util.Set;
import java.util.stream.Collectors;
@@ -38,7 +40,8 @@ class TensorConverter {
{
Map<String, OnnxTensor> result = new HashMap<>();
for (String name : tensorMap.keySet()) {
- Tensor vespaTensor = tensorMap.get(name);
+ Tensor vespaTensor = tensorMap.get(name);
+ name = toOnnxName(name, session.getInputInfo().keySet());
TensorInfo onnxTensorInfo = toTensorInfo(session.getInputInfo().get(name).getInfo());
OnnxTensor onnxTensor = toOnnxTensor(vespaTensor, onnxTensorInfo, env);
result.put(name, onnxTensor);
@@ -143,7 +146,22 @@ class TensorConverter {
}
static Map<String, TensorType> toVespaTypes(Map<String, NodeInfo> infoMap) {
- return infoMap.entrySet().stream().collect(Collectors.toMap(e -> e.getKey(), e -> toVespaType(e.getValue().getInfo())));
+ return infoMap.entrySet().stream().collect(Collectors.toMap(e -> asValidName(e.getKey()),
+ e -> toVespaType(e.getValue().getInfo())));
+ }
+
+ static String asValidName(String name) {
+ return OnnxImporter.asValidIdentifier(name);
+ }
+
+ static String toOnnxName(String name, Set<String> onnxNames) {
+ if (onnxNames.contains(name))
+ return name;
+ for (String onnxName : onnxNames) {
+ if (asValidName(onnxName).equals(name))
+ return onnxName;
+ }
+ throw new IllegalArgumentException("ONNX model has no input with name " + name);
}
static TensorType toVespaType(ValueInfo valueInfo) {