diff options
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java | 238 |
1 files changed, 0 insertions, 238 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java deleted file mode 100644 index 95727acb5b4..00000000000 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java +++ /dev/null @@ -1,238 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.rankingexpression.importer.tensorflow; - -import ai.vespa.rankingexpression.importer.OrderedTensorType; -import com.yahoo.tensor.IndexedTensor; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; -import org.tensorflow.framework.DataType; -import org.tensorflow.framework.TensorProto; -import org.tensorflow.framework.TensorShapeProto; - -import java.nio.ByteBuffer; -import java.nio.DoubleBuffer; -import java.nio.FloatBuffer; -import java.nio.IntBuffer; -import java.nio.LongBuffer; -import java.util.List; - -/** - * Converts TensorFlow tensors into Vespa tensors. - * - * @author bratseth - * @author lesters - */ -public class TensorConverter { - - public static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor) { - return toVespaTensor(tfTensor, "d"); - } - - private static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) { - TensorType type = TypeConverter.typeFrom(tfTensor, dimensionPrefix); - Values values = readValuesOf(tfTensor); - IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type); - for (int i = 0; i < values.size(); i++) - builder.cellByDirectIndex(i, values.get(i)); - return builder.build(); - } - - static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, OrderedTensorType type) { - Values values = readValuesOf(tfTensor); - IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type.type()); - for (int i = 0; i < values.size(); i++) { - builder.cellByDirectIndex(type.toDirectIndex(i), values.get(i)); - } - return builder.build(); - } - - static Tensor toVespaTensor(TensorProto tensorProto, OrderedTensorType type) { - Values values = readValuesOf(tensorProto); - if (values.size() == 0) { // Might be stored as "tensor_content" instead - return toVespaTensor(readTensorContentOf(tensorProto), type); - } - IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type.type()); - for (int i = 0; i < values.size(); ++i) - builder.cellByDirectIndex(i, values.get(i)); - return builder.build(); - } - - public static Long tensorSize(TensorType type) { - Long size = 1L; - for (TensorType.Dimension dimension : type.dimensions()) { - size *= dimensionSize(dimension); - } - return size; - } - - private static Long dimensionSize(TensorType.Dimension dim) { - return dim.size().orElseThrow(() -> new IllegalArgumentException("Dimension has no size")); - } - - private static Values readValuesOf(org.tensorflow.Tensor<?> tfTensor) { - switch (tfTensor.dataType()) { - case DOUBLE: return new DoubleValues(tfTensor); - case FLOAT: return new FloatValues(tfTensor); - case BOOL: return new BoolValues(tfTensor); - case UINT8: return new IntValues(tfTensor); - case INT32: return new IntValues(tfTensor); - case INT64: return new LongValues(tfTensor); - default: throw new IllegalArgumentException("Cannot convert a tensor with elements of type " + - tfTensor.dataType() + " to a Vespa tensor"); - } - } - - private static Values readValuesOf(TensorProto tensorProto) { - switch (tensorProto.getDtype()) { - case DT_BOOL: return new ProtoBoolValues(tensorProto); - case DT_HALF: return new ProtoHalfValues(tensorProto); - case DT_INT16: case DT_INT32: return new ProtoIntValues(tensorProto); - case DT_INT64: return new ProtoInt64Values(tensorProto); - case DT_FLOAT: return new ProtoFloatValues(tensorProto); - case DT_DOUBLE: return new ProtoDoubleValues(tensorProto); - default: throw new IllegalArgumentException("Unsupported data type in attribute tensor import"); - } - } - - private static Class dataTypeToClass(DataType dataType) { - switch (dataType) { - case DT_BOOL: return Boolean.class; - case DT_INT16: return Short.class; - case DT_INT32: return Integer.class; - case DT_INT64: return Long.class; - case DT_HALF: return Float.class; - case DT_FLOAT: return Float.class; - case DT_DOUBLE: return Double.class; - default: throw new IllegalArgumentException("Unsupported data type in attribute tensor import"); - } - } - - private static org.tensorflow.Tensor readTensorContentOf(TensorProto tensorProto) { - return org.tensorflow.Tensor.create(dataTypeToClass(tensorProto.getDtype()), - asSizeArray(tensorProto.getTensorShape().getDimList()), - tensorProto.getTensorContent().asReadOnlyByteBuffer()); - } - - private static long[] asSizeArray(List<TensorShapeProto.Dim> dimensions) { - long[] sizes = new long[dimensions.size()]; - for (int i = 0; i < dimensions.size(); i++) - sizes[i] = dimensions.get(i).getSize(); - return sizes; - } - - /** Allows reading values from buffers of various numeric types as bytes */ - private static abstract class Values { - abstract double get(int i); - abstract int size(); - } - - private static abstract class TensorFlowValues extends Values { - private final int size; - TensorFlowValues(int size) { - this.size = size; - } - @Override int size() { return this.size; } - } - - private static class DoubleValues extends TensorFlowValues { - private final DoubleBuffer values; - DoubleValues(org.tensorflow.Tensor<?> tfTensor) { - super(tfTensor.numElements()); - values = DoubleBuffer.allocate(tfTensor.numElements()); - tfTensor.writeTo(values); - } - @Override double get(int i) { - return values.get(i); - } - } - - private static class FloatValues extends TensorFlowValues { - private final FloatBuffer values; - FloatValues(org.tensorflow.Tensor<?> tfTensor) { - super(tfTensor.numElements()); - values = FloatBuffer.allocate(tfTensor.numElements()); - tfTensor.writeTo(values); - } - @Override double get(int i) { - return values.get(i); - } - } - - private static class BoolValues extends TensorFlowValues { - private final ByteBuffer values; - BoolValues(org.tensorflow.Tensor<?> tfTensor) { - super(tfTensor.numElements()); - values = ByteBuffer.allocate(tfTensor.numElements()); - tfTensor.writeTo(values); - } - @Override double get(int i) { - return values.get(i); - } - } - - private static class IntValues extends TensorFlowValues { - private final IntBuffer values; - IntValues(org.tensorflow.Tensor<?> tfTensor) { - super(tfTensor.numElements()); - values = IntBuffer.allocate(tfTensor.numElements()); - tfTensor.writeTo(values); - } - @Override double get(int i) { - return values.get(i); - } - } - - private static class LongValues extends TensorFlowValues { - private final LongBuffer values; - LongValues(org.tensorflow.Tensor<?> tfTensor) { - super(tfTensor.numElements()); - values = LongBuffer.allocate(tfTensor.numElements()); - tfTensor.writeTo(values); - } - @Override double get(int i) { - return values.get(i); - } - } - - private static abstract class ProtoValues extends Values { - final TensorProto tensorProto; - ProtoValues(TensorProto tensorProto) { this.tensorProto = tensorProto; } - } - - private static class ProtoBoolValues extends ProtoValues { - ProtoBoolValues(TensorProto tensorProto) { super(tensorProto); } - @Override double get(int i) { return tensorProto.getBoolVal(i) ? 1.0 : 0.0; } - @Override int size() { return tensorProto.getBoolValCount(); } - } - - private static class ProtoHalfValues extends ProtoValues { - ProtoHalfValues(TensorProto tensorProto) { super(tensorProto); } - @Override double get(int i) { return tensorProto.getHalfVal(i); } - @Override int size() { return tensorProto.getHalfValCount(); } - } - - private static class ProtoIntValues extends ProtoValues { - ProtoIntValues(TensorProto tensorProto) { super(tensorProto); } - @Override double get(int i) { return tensorProto.getIntVal(i); } - @Override int size() { return tensorProto.getIntValCount(); } - } - - private static class ProtoInt64Values extends ProtoValues { - ProtoInt64Values(TensorProto tensorProto) { super(tensorProto); } - @Override double get(int i) { return tensorProto.getInt64Val(i); } - @Override int size() { return tensorProto.getInt64ValCount(); } - } - - private static class ProtoFloatValues extends ProtoValues { - ProtoFloatValues(TensorProto tensorProto) { super(tensorProto); } - @Override double get(int i) { return tensorProto.getFloatVal(i); } - @Override int size() { return tensorProto.getFloatValCount(); } - } - - private static class ProtoDoubleValues extends ProtoValues { - ProtoDoubleValues(TensorProto tensorProto) { super(tensorProto); } - @Override double get(int i) { return tensorProto.getDoubleVal(i); } - @Override int size() { return tensorProto.getDoubleValCount(); } - } - -} |