aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java26
1 files changed, 21 insertions, 5 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java
index 6c92ffa6055..a4fe38cce95 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java
@@ -5,6 +5,7 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import org.tensorflow.DataType;
import org.tensorflow.framework.TensorProto;
import java.nio.ByteBuffer;
@@ -27,7 +28,7 @@ public class TensorConverter {
}
private static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) {
- TensorType type = toVespaTensorType(tfTensor.shape(), dimensionPrefix);
+ TensorType type = toVespaTensorType(tfTensor, dimensionPrefix);
Values values = readValuesOf(tfTensor);
IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type);
for (int i = 0; i < values.size(); i++)
@@ -53,10 +54,10 @@ public class TensorConverter {
return builder.build();
}
- private static TensorType toVespaTensorType(long[] shape, String dimensionPrefix) {
- TensorType.Builder b = new TensorType.Builder();
+ private static TensorType toVespaTensorType(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) {
+ TensorType.Builder b = new TensorType.Builder(toValueType(tfTensor.dataType()));
int dimensionIndex = 0;
- for (long dimensionSize : shape) {
+ for (long dimensionSize : tfTensor.shape()) {
if (dimensionSize == 0) dimensionSize = 1; // TensorFlow ...
b.indexed(dimensionPrefix + (dimensionIndex++), dimensionSize);
}
@@ -85,7 +86,7 @@ public class TensorConverter {
case INT64: return new LongValues(tfTensor);
}
throw new IllegalArgumentException("Cannot convert a tensor with elements of type " +
- tfTensor.dataType() + " to a Vespa tensor");
+ tfTensor.dataType() + " to a Vespa tensor");
}
private static Values readValuesOf(TensorProto tensorProto) {
@@ -107,6 +108,21 @@ public class TensorConverter {
throw new IllegalArgumentException("Unsupported data type in attribute tensor import");
}
+ /** TensorFlow has two different DataType classes. This must be kept in sync with TypeConverter.toValueType */
+ static TensorType.Value toValueType(DataType dataType) {
+ switch (dataType) {
+ case FLOAT: return TensorType.Value.FLOAT;
+ case DOUBLE: return TensorType.Value.DOUBLE;
+ // Imperfect conversion, for now:
+ case BOOL: return TensorType.Value.FLOAT;
+ case INT32: return TensorType.Value.FLOAT;
+ case UINT8: return TensorType.Value.FLOAT;
+ case INT64: return TensorType.Value.DOUBLE;
+ default: throw new IllegalArgumentException("A TensorFlow tensor with data type " + dataType +
+ " cannot be converted to a Vespa tensor type");
+ }
+ }
+
/** Allows reading values from buffers of various numeric types as bytes */
private static abstract class Values {
abstract double get(int i);