diff options
author | Lester Solbakken <lesters@oath.com> | 2020-10-25 22:04:16 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2020-10-25 22:04:16 +0100 |
commit | 239f2137a05709a83163a8cae68bc04d5b0a27e8 (patch) | |
tree | 0cc9b306024b51b4ca18e721877179afc99b1cdc | |
parent | 38e2a6a325db457456e04ce8385f23b12a5da54d (diff) |
Properly handle ONNX dimensions of size -1
5 files changed, 91 insertions, 30 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java index 58213186f78..5e8b8579ee6 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java @@ -150,54 +150,64 @@ public class OnnxModel { if (onnxOutputType == null) { throw new IllegalArgumentException("Could not find type for output '" + onnxName + "' " + "in '" + name + "'"); } - if (containsSymbolicDimensionSizes(onnxOutputType)) { - return getTensorTypeWithSymbolicDimensions(onnxOutputType, context); + if (allDimensionSizesAreKnown(onnxOutputType)) { + return vespaTypes.computeIfAbsent(onnxName, v -> typeFrom(onnxOutputType)); } - return vespaTypes.computeIfAbsent(onnxName, v -> typeFrom(onnxOutputType)); + return getTensorTypeWithUnknownDimensions(onnxOutputType, context); } - private TensorType getTensorTypeWithSymbolicDimensions(Onnx.TypeProto onnxOutputType, MapEvaluationTypeContext context) { - Map<String, Long> symbolicSizes = resolveSymbolicDimensionSizes(context); - if (symbolicSizes.isEmpty()) { - return TensorType.empty; // Context is probably a rank profile not using this ONNX model - } - return typeFrom(onnxOutputType, symbolicSizes); + private static boolean allDimensionSizesAreKnown(Onnx.TypeProto type) { + return type.getTensorType().getShape().getDimList().stream().noneMatch(d -> + (d.hasDimParam() && ! d.hasDimValue()) || d.getDimValue() == -1); } - private Map<String, Long> resolveSymbolicDimensionSizes(MapEvaluationTypeContext context) { + private TensorType getTensorTypeWithUnknownDimensions(Onnx.TypeProto onnxOutputType, MapEvaluationTypeContext context) { + long unboundSize = 0; Map<String, Long> symbolicSizes = new HashMap<>(); - for (String onnxInputName : inputTypes.keySet()) { + for (String onnxInputName : inputTypes.keySet()) { Onnx.TypeProto onnxType = inputTypes.get(onnxInputName); - if ( ! containsSymbolicDimensionSizes(onnxType)) { + if (allDimensionSizesAreKnown(onnxType)) { continue; } Optional<TensorType> vespaType = resolveInputType(onnxInputName, context); if (vespaType.isEmpty()) { - return Collections.emptyMap(); + return TensorType.empty; } var onnxDimensions = onnxType.getTensorType().getShape().getDimList(); var vespaDimensions = vespaType.get().dimensions(); if (vespaDimensions.size() != onnxDimensions.size()) { - return Collections.emptyMap(); + return TensorType.empty; } for (int i = 0; i < vespaDimensions.size(); ++i) { - if (vespaDimensions.get(i).size().isEmpty() || ! onnxDimensions.get(i).hasDimParam()) { + if (vespaDimensions.get(i).size().isEmpty()) { continue; } - String symbolicName = onnxDimensions.get(i).getDimParam(); Long size = vespaDimensions.get(i).size().get(); - if (symbolicSizes.containsKey(symbolicName) && ! symbolicSizes.get(symbolicName).equals(size)) { - throw new IllegalArgumentException("Found conflicting sizes for symbolic dimension " + - "'" + symbolicName + "' for input '" + onnxInputName + "' in ONNX model '" + name + "'"); + + // Handle dimensions with size -1 - typically batch dimensions + if (onnxDimensions.get(i).getDimValue() == -1) { + if (unboundSize != 0 && unboundSize != size) { + throw new IllegalArgumentException("Found conflicting sizes for unbound dimension " + + "for type '" + onnxOutputType + "' in ONNX model '" + name + "'"); + } + unboundSize = size; + + // Handle dimensions with symbolic names + } else if (onnxDimensions.get(i).hasDimParam()) { + String symbolicName = onnxDimensions.get(i).getDimParam(); + if (symbolicSizes.containsKey(symbolicName) && ! symbolicSizes.get(symbolicName).equals(size)) { + throw new IllegalArgumentException("Found conflicting sizes for symbolic dimension '" + + symbolicName + "' for input '" + onnxInputName + "' in ONNX model '" + name + "'"); + } + symbolicSizes.put(symbolicName, size); } - symbolicSizes.put(symbolicName, size); } } - return symbolicSizes; + return typeFrom(onnxOutputType, symbolicSizes, unboundSize); } private Optional<TensorType> resolveInputType(String onnxInputName, MapEvaluationTypeContext context) { @@ -217,15 +227,11 @@ public class OnnxModel { return Optional.empty(); // if this context does not contain this input } - private static boolean containsSymbolicDimensionSizes(Onnx.TypeProto type) { - return type.getTensorType().getShape().getDimList().stream().anyMatch(d -> d.hasDimParam() && ! d.hasDimValue()); - } - private static TensorType typeFrom(Onnx.TypeProto type) { - return typeFrom(type, null); + return typeFrom(type, null, 0); } - private static TensorType typeFrom(Onnx.TypeProto type, Map<String, Long> symbolicSizes) { + private static TensorType typeFrom(Onnx.TypeProto type, Map<String, Long> symbolicSizes, long unboundSize) { String dimensionPrefix = "d"; // standard naming convention: d0, d1, ... Onnx.TensorShapeProto shape = type.getTensorType().getShape(); TensorType.Builder builder = new TensorType.Builder(toValueType(type.getTensorType().getElemType())); @@ -244,6 +250,9 @@ public class OnnxModel { onnxDimensionSize = unknownSizes.iterator().next(); } } + if (onnxDimensionSize < 0) { + onnxDimensionSize = unboundSize; + } if (onnxDimensionSize <= 0) { throw new IllegalArgumentException("Unable to determine fixed dimension size when converting from " + "ONNX type: " + type + " to Vespa tensor type."); diff --git a/config-model/src/test/integration/onnx-model/files/create_unbound_model.py b/config-model/src/test/integration/onnx-model/files/create_unbound_model.py new file mode 100755 index 00000000000..abf733ea43f --- /dev/null +++ b/config-model/src/test/integration/onnx-model/files/create_unbound_model.py @@ -0,0 +1,12 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +import onnx +from onnx import helper, TensorProto + +INPUT = helper.make_tensor_value_info('input', TensorProto.FLOAT, [-1, 2]) +OUTPUT = helper.make_tensor_value_info('output', TensorProto.FLOAT, [-1, 2]) + +nodes = [helper.make_node('Identity', ['input'], ['output'])] +graph_def = helper.make_graph( nodes, 'simple_scoring', [INPUT], [OUTPUT]) +model_def = helper.make_model(graph_def, producer_name='create_unbound_model.py') +onnx.save(model_def, 'unbound_model.onnx') diff --git a/config-model/src/test/integration/onnx-model/files/unbound_model.onnx b/config-model/src/test/integration/onnx-model/files/unbound_model.onnx new file mode 100644 index 00000000000..155b3125256 --- /dev/null +++ b/config-model/src/test/integration/onnx-model/files/unbound_model.onnx @@ -0,0 +1,11 @@ +create_unbound_model.py:p + +inputoutput"Identitysimple_scoringZ +input + +ÿÿÿÿÿÿÿÿÿ +b! +output + +ÿÿÿÿÿÿÿÿÿ +B
\ No newline at end of file diff --git a/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd b/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd index 6e9ba356293..a87222e77ee 100644 --- a/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd +++ b/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd @@ -35,6 +35,12 @@ search test { output output: my_output } + onnx-model unbound_model { + file: files/unbound_model.onnx + input input: my_function + output output: my_output + } + rank-profile test_model_config { function my_function() { expression: tensor(d0[2])(1) @@ -93,4 +99,14 @@ search test { } } + rank-profile test_unbound_model { + function my_function() { + expression: tensor(d0[1],d1[2])(d1) + } + first-phase { + expression: onnxModel(unbound_model){d0:0,d1:1} + } + } + + } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java index 5060aafb55f..4eb8681c374 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java @@ -25,7 +25,7 @@ public class RankingExpressionWithOnnxModelTestCase { OnnxModelsConfig.Builder builder = new OnnxModelsConfig.Builder(); ((OnnxModelsConfig.Producer) db).getConfig(builder); OnnxModelsConfig config = new OnnxModelsConfig(builder); - assertEquals(5, config.model().size()); + assertEquals(6, config.model().size()); assertEquals("my_model", config.model(0).name()); assertEquals(3, config.model(0).input().size()); @@ -62,17 +62,24 @@ public class RankingExpressionWithOnnxModelTestCase { assertEquals(3, config.model(3).input().size()); assertEquals(3, config.model(3).output().size()); - assertEquals("dynamic_model", config.model(4).name()); + assertEquals("dynamic_model", config.model(5).name()); + assertEquals(1, config.model(5).input().size()); + assertEquals(1, config.model(5).output().size()); + assertEquals("rankingExpression(my_function)", config.model(5).input(0).source()); + + assertEquals("unbound_model", config.model(4).name()); assertEquals(1, config.model(4).input().size()); assertEquals(1, config.model(4).output().size()); assertEquals("rankingExpression(my_function)", config.model(4).input(0).source()); + + } private void assertTransformedFeature(DocumentDatabase db) { RankProfilesConfig.Builder builder = new RankProfilesConfig.Builder(); ((RankProfilesConfig.Producer) db).getConfig(builder); RankProfilesConfig config = new RankProfilesConfig(builder); - assertEquals(7, config.rankprofile().size()); + assertEquals(8, config.rankprofile().size()); assertEquals("test_model_config", config.rankprofile(2).name()); assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(2).fef().property(0).name()); @@ -109,6 +116,12 @@ public class RankingExpressionWithOnnxModelTestCase { assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(6).fef().property(5).name()); assertEquals("onnxModel(dynamic_model).my_output{d0:0, d1:2}", config.rankprofile(6).fef().property(5).value()); + assertEquals("test_unbound_model", config.rankprofile(7).name()); + assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(7).fef().property(0).name()); + assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(7).fef().property(3).name()); + assertEquals("onnxModel(unbound_model).my_output{d0:0, d1:1}", config.rankprofile(7).fef().property(3).value()); + + } } |