diff options
author | Lester Solbakken <lesters@oath.com> | 2019-11-22 11:14:35 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2019-11-22 11:14:35 +0100 |
commit | 69e4f6bf072d8ebfb12761c450f2bdacf86e226c (patch) | |
tree | 9c944d4ebced2489b40494c4d0a215a072126c5b /model-integration | |
parent | 7c7a5e8475a3a9221fa9b308c3aae14ca27e550e (diff) |
Convert onnx dimensions of size 0
Diffstat (limited to 'model-integration')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java index 98ff8ca735f..7c8038cea66 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java @@ -23,7 +23,8 @@ class TypeConverter { int vespaIndex = type.dimensionMap(onnxIndex); Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(onnxIndex); TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex); - if (onnxDimension.getDimValue() != vespaDimension.size().orElse(-1L)) { + long onnxDimensionSize = onnxDimension.getDimValue() == 0 ? 1 : onnxDimension.getDimValue(); + if (onnxDimensionSize != vespaDimension.size().orElse(-1L)) { throw new IllegalArgumentException("Onnx dimensions of does not match Vespa dimensions"); } } @@ -37,8 +38,9 @@ class TypeConverter { for (int i = 0; i < shape.getDimCount(); ++ i) { String dimensionName = dimensionPrefix + i; Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i); - if (onnxDimension.getDimValue() >= 0) { - builder.add(TensorType.Dimension.indexed(dimensionName, onnxDimension.getDimValue())); + long onnxDimensionSize = onnxDimension.getDimValue() == 0 ? 1 : onnxDimension.getDimValue(); + if (onnxDimensionSize >= 0) { + builder.add(TensorType.Dimension.indexed(dimensionName, onnxDimensionSize)); } else { builder.add(TensorType.Dimension.indexed(dimensionName)); } |