diff options
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java | 26 |
1 files changed, 21 insertions, 5 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java index 6c92ffa6055..a4fe38cce95 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java @@ -5,6 +5,7 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import org.tensorflow.DataType; import org.tensorflow.framework.TensorProto; import java.nio.ByteBuffer; @@ -27,7 +28,7 @@ public class TensorConverter { } private static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) { - TensorType type = toVespaTensorType(tfTensor.shape(), dimensionPrefix); + TensorType type = toVespaTensorType(tfTensor, dimensionPrefix); Values values = readValuesOf(tfTensor); IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type); for (int i = 0; i < values.size(); i++) @@ -53,10 +54,10 @@ public class TensorConverter { return builder.build(); } - private static TensorType toVespaTensorType(long[] shape, String dimensionPrefix) { - TensorType.Builder b = new TensorType.Builder(); + private static TensorType toVespaTensorType(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) { + TensorType.Builder b = new TensorType.Builder(toValueType(tfTensor.dataType())); int dimensionIndex = 0; - for (long dimensionSize : shape) { + for (long dimensionSize : tfTensor.shape()) { if (dimensionSize == 0) dimensionSize = 1; // TensorFlow ... b.indexed(dimensionPrefix + (dimensionIndex++), dimensionSize); } @@ -85,7 +86,7 @@ public class TensorConverter { case INT64: return new LongValues(tfTensor); } throw new IllegalArgumentException("Cannot convert a tensor with elements of type " + - tfTensor.dataType() + " to a Vespa tensor"); + tfTensor.dataType() + " to a Vespa tensor"); } private static Values readValuesOf(TensorProto tensorProto) { @@ -107,6 +108,21 @@ public class TensorConverter { throw new IllegalArgumentException("Unsupported data type in attribute tensor import"); } + /** TensorFlow has two different DataType classes. This must be kept in sync with TypeConverter.toValueType */ + static TensorType.Value toValueType(DataType dataType) { + switch (dataType) { + case FLOAT: return TensorType.Value.FLOAT; + case DOUBLE: return TensorType.Value.DOUBLE; + // Imperfect conversion, for now: + case BOOL: return TensorType.Value.FLOAT; + case INT32: return TensorType.Value.FLOAT; + case UINT8: return TensorType.Value.FLOAT; + case INT64: return TensorType.Value.DOUBLE; + default: throw new IllegalArgumentException("A TensorFlow tensor with data type " + dataType + + " cannot be converted to a Vespa tensor type"); + } + } + /** Allows reading values from buffers of various numeric types as bytes */ private static abstract class Values { abstract double get(int i); |