summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2019-11-22 11:14:35 +0100
committerLester Solbakken <lesters@oath.com>2019-11-22 11:14:35 +0100
commit69e4f6bf072d8ebfb12761c450f2bdacf86e226c (patch)
tree9c944d4ebced2489b40494c4d0a215a072126c5b /model-integration
parent7c7a5e8475a3a9221fa9b308c3aae14ca27e550e (diff)
Convert onnx dimensions of size 0
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java8
1 files changed, 5 insertions, 3 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
index 98ff8ca735f..7c8038cea66 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
@@ -23,7 +23,8 @@ class TypeConverter {
int vespaIndex = type.dimensionMap(onnxIndex);
Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(onnxIndex);
TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex);
- if (onnxDimension.getDimValue() != vespaDimension.size().orElse(-1L)) {
+ long onnxDimensionSize = onnxDimension.getDimValue() == 0 ? 1 : onnxDimension.getDimValue();
+ if (onnxDimensionSize != vespaDimension.size().orElse(-1L)) {
throw new IllegalArgumentException("Onnx dimensions of does not match Vespa dimensions");
}
}
@@ -37,8 +38,9 @@ class TypeConverter {
for (int i = 0; i < shape.getDimCount(); ++ i) {
String dimensionName = dimensionPrefix + i;
Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i);
- if (onnxDimension.getDimValue() >= 0) {
- builder.add(TensorType.Dimension.indexed(dimensionName, onnxDimension.getDimValue()));
+ long onnxDimensionSize = onnxDimension.getDimValue() == 0 ? 1 : onnxDimension.getDimValue();
+ if (onnxDimensionSize >= 0) {
+ builder.add(TensorType.Dimension.indexed(dimensionName, onnxDimensionSize));
} else {
builder.add(TensorType.Dimension.indexed(dimensionName));
}