summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2018-01-22 14:46:14 +0100
committerLester Solbakken <lesters@oath.com>2018-01-22 14:46:14 +0100
commit95efa074a0382ae6dcb6ebaf2256a9f5915e8a28 (patch)
tree8a5a26c0b97eb13661a7b9d93b158afba06e203a /searchlib
parentf6657d14fe38dbcc431e3bfb5a5a67473b8c97a3 (diff)
Support negative values dimension sizes in tensorflow reshape import
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java6
1 files changed, 4 insertions, 2 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
index f9cec4a14f0..2581cd46286 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
@@ -35,6 +35,7 @@ import java.util.List;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
import java.util.stream.Collectors;
+import java.util.stream.StreamSupport;
/**
* Contains mappings of TensorFlow operations to the corresponding Vespa tensor functions.
@@ -319,10 +320,12 @@ class OperationMapper {
for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) {
Tensor.Cell cell = cellIterator.next();
int size = cell.getValue().intValue();
+ if (size < 0) {
+ size = -1 * (int)shape.reduce(Reduce.Aggregator.prod).asDouble() / tensorSize(inputType).intValue();
+ }
outputTypeBuilder.indexed(String.format("temp_%d", dimensionIndex), size);
dimensionIndex++;
}
-
return reshape(inputFunction, inputType, outputTypeBuilder.build());
}
@@ -403,7 +406,6 @@ class OperationMapper {
@Override public double applyAsDouble(double a, double b) { return a * (1.0 - b); }
@Override public String toString() { return "f(a,b)(a * (1-b))"; }
});
-
TensorFunction outputFunction = new Join(xCond, yCond, ScalarFunctions.add());
return new TypedTensorFunction(x.type(), outputFunction);
}