aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/main/java
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2018-01-22 14:25:03 +0100
committerLester Solbakken <lesters@oath.com>2018-01-22 14:25:03 +0100
commitf6657d14fe38dbcc431e3bfb5a5a67473b8c97a3 (patch)
treef0fbc450d695dc7286d4f4d0b15c7e4e9cb30c97 /searchlib/src/main/java
parentf37961e976bf6fb40b51c2f5bc01b7f3b2adcec5 (diff)
Support negative dimesions in import of expanddims
Diffstat (limited to 'searchlib/src/main/java')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java10
1 files changed, 7 insertions, 3 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
index f77699e8f9e..f9cec4a14f0 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
@@ -281,11 +281,15 @@ class OperationMapper {
if (axis.type().rank() != 0) {
throw new IllegalArgumentException("Axis argument to ExpandDims must be a scalar");
}
- int dimensionToInsert = (int)axis.asDouble();
TensorFunction inputFunction = arguments.get(0).function();
TensorType inputType = arguments.get(0).type();
+ int dimensionToInsert = (int)axis.asDouble();
+ if (dimensionToInsert < 0) {
+ dimensionToInsert = inputType.dimensions().size() - dimensionToInsert;
+ }
+
TensorType.Builder outputTypeBuilder = new TensorType.Builder();
int dimensionIndex = 0;
for (int i = 0; i < inputType.dimensions().size() + 1; ++i) {
@@ -294,8 +298,8 @@ class OperationMapper {
if (i == dimensionToInsert) {
size = 1L;
} else {
- TensorType.Dimension dimension = inputType.dimensions().get(dimensionIndex);
- size = dimension.size().orElseThrow(() -> new IllegalArgumentException("Dimension has no size"));
+ size = dimensionSize(inputType.dimensions().get(dimensionIndex));
+ dimensionIndex++;
}
outputTypeBuilder.indexed(name, size);
}