diff options
author | Lester Solbakken <lesters@oath.com> | 2018-01-22 14:25:03 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2018-01-22 14:25:03 +0100 |
commit | f6657d14fe38dbcc431e3bfb5a5a67473b8c97a3 (patch) | |
tree | f0fbc450d695dc7286d4f4d0b15c7e4e9cb30c97 /searchlib/src/main/java | |
parent | f37961e976bf6fb40b51c2f5bc01b7f3b2adcec5 (diff) |
Support negative dimesions in import of expanddims
Diffstat (limited to 'searchlib/src/main/java')
-rw-r--r-- | searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java | 10 |
1 files changed, 7 insertions, 3 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java index f77699e8f9e..f9cec4a14f0 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java @@ -281,11 +281,15 @@ class OperationMapper { if (axis.type().rank() != 0) { throw new IllegalArgumentException("Axis argument to ExpandDims must be a scalar"); } - int dimensionToInsert = (int)axis.asDouble(); TensorFunction inputFunction = arguments.get(0).function(); TensorType inputType = arguments.get(0).type(); + int dimensionToInsert = (int)axis.asDouble(); + if (dimensionToInsert < 0) { + dimensionToInsert = inputType.dimensions().size() - dimensionToInsert; + } + TensorType.Builder outputTypeBuilder = new TensorType.Builder(); int dimensionIndex = 0; for (int i = 0; i < inputType.dimensions().size() + 1; ++i) { @@ -294,8 +298,8 @@ class OperationMapper { if (i == dimensionToInsert) { size = 1L; } else { - TensorType.Dimension dimension = inputType.dimensions().get(dimensionIndex); - size = dimension.size().orElseThrow(() -> new IllegalArgumentException("Dimension has no size")); + size = dimensionSize(inputType.dimensions().get(dimensionIndex)); + dimensionIndex++; } outputTypeBuilder.indexed(name, size); } |