summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2020-10-25 22:04:16 +0100
committerLester Solbakken <lesters@oath.com>2020-10-25 22:04:16 +0100
commit239f2137a05709a83163a8cae68bc04d5b0a27e8 (patch)
tree0cc9b306024b51b4ca18e721877179afc99b1cdc
parent38e2a6a325db457456e04ce8385f23b12a5da54d (diff)
Properly handle ONNX dimensions of size -1
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java63
-rwxr-xr-xconfig-model/src/test/integration/onnx-model/files/create_unbound_model.py12
-rw-r--r--config-model/src/test/integration/onnx-model/files/unbound_model.onnx11
-rw-r--r--config-model/src/test/integration/onnx-model/searchdefinitions/test.sd16
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java19
5 files changed, 91 insertions, 30 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
index 58213186f78..5e8b8579ee6 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
@@ -150,54 +150,64 @@ public class OnnxModel {
if (onnxOutputType == null) {
throw new IllegalArgumentException("Could not find type for output '" + onnxName + "' " + "in '" + name + "'");
}
- if (containsSymbolicDimensionSizes(onnxOutputType)) {
- return getTensorTypeWithSymbolicDimensions(onnxOutputType, context);
+ if (allDimensionSizesAreKnown(onnxOutputType)) {
+ return vespaTypes.computeIfAbsent(onnxName, v -> typeFrom(onnxOutputType));
}
- return vespaTypes.computeIfAbsent(onnxName, v -> typeFrom(onnxOutputType));
+ return getTensorTypeWithUnknownDimensions(onnxOutputType, context);
}
- private TensorType getTensorTypeWithSymbolicDimensions(Onnx.TypeProto onnxOutputType, MapEvaluationTypeContext context) {
- Map<String, Long> symbolicSizes = resolveSymbolicDimensionSizes(context);
- if (symbolicSizes.isEmpty()) {
- return TensorType.empty; // Context is probably a rank profile not using this ONNX model
- }
- return typeFrom(onnxOutputType, symbolicSizes);
+ private static boolean allDimensionSizesAreKnown(Onnx.TypeProto type) {
+ return type.getTensorType().getShape().getDimList().stream().noneMatch(d ->
+ (d.hasDimParam() && ! d.hasDimValue()) || d.getDimValue() == -1);
}
- private Map<String, Long> resolveSymbolicDimensionSizes(MapEvaluationTypeContext context) {
+ private TensorType getTensorTypeWithUnknownDimensions(Onnx.TypeProto onnxOutputType, MapEvaluationTypeContext context) {
+ long unboundSize = 0;
Map<String, Long> symbolicSizes = new HashMap<>();
- for (String onnxInputName : inputTypes.keySet()) {
+ for (String onnxInputName : inputTypes.keySet()) {
Onnx.TypeProto onnxType = inputTypes.get(onnxInputName);
- if ( ! containsSymbolicDimensionSizes(onnxType)) {
+ if (allDimensionSizesAreKnown(onnxType)) {
continue;
}
Optional<TensorType> vespaType = resolveInputType(onnxInputName, context);
if (vespaType.isEmpty()) {
- return Collections.emptyMap();
+ return TensorType.empty;
}
var onnxDimensions = onnxType.getTensorType().getShape().getDimList();
var vespaDimensions = vespaType.get().dimensions();
if (vespaDimensions.size() != onnxDimensions.size()) {
- return Collections.emptyMap();
+ return TensorType.empty;
}
for (int i = 0; i < vespaDimensions.size(); ++i) {
- if (vespaDimensions.get(i).size().isEmpty() || ! onnxDimensions.get(i).hasDimParam()) {
+ if (vespaDimensions.get(i).size().isEmpty()) {
continue;
}
- String symbolicName = onnxDimensions.get(i).getDimParam();
Long size = vespaDimensions.get(i).size().get();
- if (symbolicSizes.containsKey(symbolicName) && ! symbolicSizes.get(symbolicName).equals(size)) {
- throw new IllegalArgumentException("Found conflicting sizes for symbolic dimension " +
- "'" + symbolicName + "' for input '" + onnxInputName + "' in ONNX model '" + name + "'");
+
+ // Handle dimensions with size -1 - typically batch dimensions
+ if (onnxDimensions.get(i).getDimValue() == -1) {
+ if (unboundSize != 0 && unboundSize != size) {
+ throw new IllegalArgumentException("Found conflicting sizes for unbound dimension " +
+ "for type '" + onnxOutputType + "' in ONNX model '" + name + "'");
+ }
+ unboundSize = size;
+
+ // Handle dimensions with symbolic names
+ } else if (onnxDimensions.get(i).hasDimParam()) {
+ String symbolicName = onnxDimensions.get(i).getDimParam();
+ if (symbolicSizes.containsKey(symbolicName) && ! symbolicSizes.get(symbolicName).equals(size)) {
+ throw new IllegalArgumentException("Found conflicting sizes for symbolic dimension '" +
+ symbolicName + "' for input '" + onnxInputName + "' in ONNX model '" + name + "'");
+ }
+ symbolicSizes.put(symbolicName, size);
}
- symbolicSizes.put(symbolicName, size);
}
}
- return symbolicSizes;
+ return typeFrom(onnxOutputType, symbolicSizes, unboundSize);
}
private Optional<TensorType> resolveInputType(String onnxInputName, MapEvaluationTypeContext context) {
@@ -217,15 +227,11 @@ public class OnnxModel {
return Optional.empty(); // if this context does not contain this input
}
- private static boolean containsSymbolicDimensionSizes(Onnx.TypeProto type) {
- return type.getTensorType().getShape().getDimList().stream().anyMatch(d -> d.hasDimParam() && ! d.hasDimValue());
- }
-
private static TensorType typeFrom(Onnx.TypeProto type) {
- return typeFrom(type, null);
+ return typeFrom(type, null, 0);
}
- private static TensorType typeFrom(Onnx.TypeProto type, Map<String, Long> symbolicSizes) {
+ private static TensorType typeFrom(Onnx.TypeProto type, Map<String, Long> symbolicSizes, long unboundSize) {
String dimensionPrefix = "d"; // standard naming convention: d0, d1, ...
Onnx.TensorShapeProto shape = type.getTensorType().getShape();
TensorType.Builder builder = new TensorType.Builder(toValueType(type.getTensorType().getElemType()));
@@ -244,6 +250,9 @@ public class OnnxModel {
onnxDimensionSize = unknownSizes.iterator().next();
}
}
+ if (onnxDimensionSize < 0) {
+ onnxDimensionSize = unboundSize;
+ }
if (onnxDimensionSize <= 0) {
throw new IllegalArgumentException("Unable to determine fixed dimension size when converting from " +
"ONNX type: " + type + " to Vespa tensor type.");
diff --git a/config-model/src/test/integration/onnx-model/files/create_unbound_model.py b/config-model/src/test/integration/onnx-model/files/create_unbound_model.py
new file mode 100755
index 00000000000..abf733ea43f
--- /dev/null
+++ b/config-model/src/test/integration/onnx-model/files/create_unbound_model.py
@@ -0,0 +1,12 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+import onnx
+from onnx import helper, TensorProto
+
+INPUT = helper.make_tensor_value_info('input', TensorProto.FLOAT, [-1, 2])
+OUTPUT = helper.make_tensor_value_info('output', TensorProto.FLOAT, [-1, 2])
+
+nodes = [helper.make_node('Identity', ['input'], ['output'])]
+graph_def = helper.make_graph( nodes, 'simple_scoring', [INPUT], [OUTPUT])
+model_def = helper.make_model(graph_def, producer_name='create_unbound_model.py')
+onnx.save(model_def, 'unbound_model.onnx')
diff --git a/config-model/src/test/integration/onnx-model/files/unbound_model.onnx b/config-model/src/test/integration/onnx-model/files/unbound_model.onnx
new file mode 100644
index 00000000000..155b3125256
--- /dev/null
+++ b/config-model/src/test/integration/onnx-model/files/unbound_model.onnx
@@ -0,0 +1,11 @@
+create_unbound_model.py:p
+
+inputoutput"Identitysimple_scoringZ
+input
+
+ ÿÿÿÿÿÿÿÿÿ
+b!
+output
+
+ ÿÿÿÿÿÿÿÿÿ
+B \ No newline at end of file
diff --git a/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd b/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd
index 6e9ba356293..a87222e77ee 100644
--- a/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd
+++ b/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd
@@ -35,6 +35,12 @@ search test {
output output: my_output
}
+ onnx-model unbound_model {
+ file: files/unbound_model.onnx
+ input input: my_function
+ output output: my_output
+ }
+
rank-profile test_model_config {
function my_function() {
expression: tensor(d0[2])(1)
@@ -93,4 +99,14 @@ search test {
}
}
+ rank-profile test_unbound_model {
+ function my_function() {
+ expression: tensor(d0[1],d1[2])(d1)
+ }
+ first-phase {
+ expression: onnxModel(unbound_model){d0:0,d1:1}
+ }
+ }
+
+
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
index 5060aafb55f..4eb8681c374 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
@@ -25,7 +25,7 @@ public class RankingExpressionWithOnnxModelTestCase {
OnnxModelsConfig.Builder builder = new OnnxModelsConfig.Builder();
((OnnxModelsConfig.Producer) db).getConfig(builder);
OnnxModelsConfig config = new OnnxModelsConfig(builder);
- assertEquals(5, config.model().size());
+ assertEquals(6, config.model().size());
assertEquals("my_model", config.model(0).name());
assertEquals(3, config.model(0).input().size());
@@ -62,17 +62,24 @@ public class RankingExpressionWithOnnxModelTestCase {
assertEquals(3, config.model(3).input().size());
assertEquals(3, config.model(3).output().size());
- assertEquals("dynamic_model", config.model(4).name());
+ assertEquals("dynamic_model", config.model(5).name());
+ assertEquals(1, config.model(5).input().size());
+ assertEquals(1, config.model(5).output().size());
+ assertEquals("rankingExpression(my_function)", config.model(5).input(0).source());
+
+ assertEquals("unbound_model", config.model(4).name());
assertEquals(1, config.model(4).input().size());
assertEquals(1, config.model(4).output().size());
assertEquals("rankingExpression(my_function)", config.model(4).input(0).source());
+
+
}
private void assertTransformedFeature(DocumentDatabase db) {
RankProfilesConfig.Builder builder = new RankProfilesConfig.Builder();
((RankProfilesConfig.Producer) db).getConfig(builder);
RankProfilesConfig config = new RankProfilesConfig(builder);
- assertEquals(7, config.rankprofile().size());
+ assertEquals(8, config.rankprofile().size());
assertEquals("test_model_config", config.rankprofile(2).name());
assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(2).fef().property(0).name());
@@ -109,6 +116,12 @@ public class RankingExpressionWithOnnxModelTestCase {
assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(6).fef().property(5).name());
assertEquals("onnxModel(dynamic_model).my_output{d0:0, d1:2}", config.rankprofile(6).fef().property(5).value());
+ assertEquals("test_unbound_model", config.rankprofile(7).name());
+ assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(7).fef().property(0).name());
+ assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(7).fef().property(3).name());
+ assertEquals("onnxModel(unbound_model).my_output{d0:0, d1:1}", config.rankprofile(7).fef().property(3).value());
+
+
}
}