diff options
author | Lester Solbakken <lesters@oath.com> | 2018-01-22 14:46:14 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2018-01-22 14:46:14 +0100 |
commit | 95efa074a0382ae6dcb6ebaf2256a9f5915e8a28 (patch) | |
tree | 8a5a26c0b97eb13661a7b9d93b158afba06e203a /searchlib | |
parent | f6657d14fe38dbcc431e3bfb5a5a67473b8c97a3 (diff) |
Support negative values dimension sizes in tensorflow reshape import
Diffstat (limited to 'searchlib')
-rw-r--r-- | searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java index f9cec4a14f0..2581cd46286 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java @@ -35,6 +35,7 @@ import java.util.List; import java.util.function.DoubleBinaryOperator; import java.util.function.DoubleUnaryOperator; import java.util.stream.Collectors; +import java.util.stream.StreamSupport; /** * Contains mappings of TensorFlow operations to the corresponding Vespa tensor functions. @@ -319,10 +320,12 @@ class OperationMapper { for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) { Tensor.Cell cell = cellIterator.next(); int size = cell.getValue().intValue(); + if (size < 0) { + size = -1 * (int)shape.reduce(Reduce.Aggregator.prod).asDouble() / tensorSize(inputType).intValue(); + } outputTypeBuilder.indexed(String.format("temp_%d", dimensionIndex), size); dimensionIndex++; } - return reshape(inputFunction, inputType, outputTypeBuilder.build()); } @@ -403,7 +406,6 @@ class OperationMapper { @Override public double applyAsDouble(double a, double b) { return a * (1.0 - b); } @Override public String toString() { return "f(a,b)(a * (1-b))"; } }); - TensorFunction outputFunction = new Join(xCond, yCond, ScalarFunctions.add()); return new TypedTensorFunction(x.type(), outputFunction); } |