aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java
blob: 6c92ffa6055b830ebd7765eb76a6c7b1447ad952 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer.tensorflow;

import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.tensorflow.framework.TensorProto;

import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;


/**
 * Converts TensorFlow tensors into Vespa tensors.
 *
 * @author bratseth
 * @author lesters
 */
public class TensorConverter {

    public static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor) {
        return toVespaTensor(tfTensor, "d");
    }

    private static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) {
        TensorType type = toVespaTensorType(tfTensor.shape(), dimensionPrefix);
        Values values = readValuesOf(tfTensor);
        IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type);
        for (int i = 0; i < values.size(); i++)
            builder.cellByDirectIndex(i, values.get(i));
        return builder.build();
    }

    static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, OrderedTensorType type) {
        Values values = readValuesOf(tfTensor);
        IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type.type());
        for (int i = 0; i < values.size(); i++) {
            builder.cellByDirectIndex(type.toDirectIndex(i), values.get(i));
        }
        return builder.build();
    }

    static Tensor toVespaTensor(TensorProto tensorProto, TensorType type) {
        IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type);
        Values values = readValuesOf(tensorProto);
        for (int i = 0; i < values.size(); ++i) {
            builder.cellByDirectIndex(i, values.get(i));
        }
        return builder.build();
    }

    private static TensorType toVespaTensorType(long[] shape, String dimensionPrefix) {
        TensorType.Builder b = new TensorType.Builder();
        int dimensionIndex = 0;
        for (long dimensionSize : shape) {
            if (dimensionSize == 0) dimensionSize = 1; // TensorFlow ...
            b.indexed(dimensionPrefix + (dimensionIndex++), dimensionSize);
        }
        return b.build();
    }

    public static Long tensorSize(TensorType type) {
        Long size = 1L;
        for (TensorType.Dimension dimension : type.dimensions()) {
            size *= dimensionSize(dimension);
        }
        return size;
    }

    private static Long dimensionSize(TensorType.Dimension dim) {
        return dim.size().orElseThrow(() -> new IllegalArgumentException("Dimension has no size"));
    }

    private static Values readValuesOf(org.tensorflow.Tensor<?> tfTensor) {
        switch (tfTensor.dataType()) {
            case DOUBLE: return new DoubleValues(tfTensor);
            case FLOAT: return new FloatValues(tfTensor);
            case BOOL: return new BoolValues(tfTensor);
            case UINT8: return new IntValues(tfTensor);
            case INT32: return new IntValues(tfTensor);
            case INT64: return new LongValues(tfTensor);
        }
        throw new IllegalArgumentException("Cannot convert a tensor with elements of type " +
                tfTensor.dataType() + " to a Vespa tensor");
    }

    private static Values readValuesOf(TensorProto tensorProto) {
        switch (tensorProto.getDtype()) {
            case DT_BOOL:
                return new ProtoBoolValues(tensorProto);
            case DT_HALF:
                return new ProtoHalfValues(tensorProto);
            case DT_INT16:
            case DT_INT32:
                return new ProtoIntValues(tensorProto);
            case DT_INT64:
                return new ProtoInt64Values(tensorProto);
            case DT_FLOAT:
                return new ProtoFloatValues(tensorProto);
            case DT_DOUBLE:
                return new ProtoDoubleValues(tensorProto);
        }
        throw new IllegalArgumentException("Unsupported data type in attribute tensor import");
    }

    /** Allows reading values from buffers of various numeric types as bytes */
    private static abstract class Values {
        abstract double get(int i);
        abstract int size();
    }

    private static abstract class TensorFlowValues extends Values {
        private final int size;
        TensorFlowValues(int size) {
            this.size = size;
        }
        @Override int size() { return this.size; }
    }

    private static class DoubleValues extends TensorFlowValues {
        private final DoubleBuffer values;
        DoubleValues(org.tensorflow.Tensor<?> tfTensor) {
            super(tfTensor.numElements());
            values = DoubleBuffer.allocate(tfTensor.numElements());
            tfTensor.writeTo(values);
        }
        @Override double get(int i) {
            return values.get(i);
        }
    }

    private static class FloatValues extends TensorFlowValues {
        private final FloatBuffer values;
        FloatValues(org.tensorflow.Tensor<?> tfTensor) {
            super(tfTensor.numElements());
            values = FloatBuffer.allocate(tfTensor.numElements());
            tfTensor.writeTo(values);
        }
        @Override double get(int i) {
            return values.get(i);
        }
    }

    private static class BoolValues extends TensorFlowValues {
        private final ByteBuffer values;
        BoolValues(org.tensorflow.Tensor<?> tfTensor) {
            super(tfTensor.numElements());
            values = ByteBuffer.allocate(tfTensor.numElements());
            tfTensor.writeTo(values);
        }
        @Override double get(int i) {
            return values.get(i);
        }
    }

    private static class IntValues extends TensorFlowValues {
        private final IntBuffer values;
        IntValues(org.tensorflow.Tensor<?> tfTensor) {
            super(tfTensor.numElements());
            values = IntBuffer.allocate(tfTensor.numElements());
            tfTensor.writeTo(values);
        }
        @Override double get(int i) {
            return values.get(i);
        }
    }

    private static class LongValues extends TensorFlowValues {
        private final LongBuffer values;
        LongValues(org.tensorflow.Tensor<?> tfTensor) {
            super(tfTensor.numElements());
            values = LongBuffer.allocate(tfTensor.numElements());
            tfTensor.writeTo(values);
        }
        @Override double get(int i) {
            return values.get(i);
        }
    }

    private static abstract class ProtoValues extends Values {
        final TensorProto tensorProto;
        ProtoValues(TensorProto tensorProto) { this.tensorProto = tensorProto; }
    }

    private static class ProtoBoolValues extends ProtoValues {
        ProtoBoolValues(TensorProto tensorProto) { super(tensorProto); }
        @Override double get(int i) { return tensorProto.getBoolVal(i) ? 1.0 : 0.0; }
        @Override int size() { return tensorProto.getBoolValCount(); }
    }

    private static class ProtoHalfValues extends ProtoValues {
        ProtoHalfValues(TensorProto tensorProto) { super(tensorProto); }
        @Override double get(int i) { return tensorProto.getHalfVal(i); }
        @Override int size() { return tensorProto.getHalfValCount(); }
    }

    private static class ProtoIntValues extends ProtoValues {
        ProtoIntValues(TensorProto tensorProto) { super(tensorProto); }
        @Override double get(int i) { return tensorProto.getIntVal(i); }
        @Override int size() { return tensorProto.getIntValCount(); }
    }

    private static class ProtoInt64Values extends ProtoValues {
        ProtoInt64Values(TensorProto tensorProto) { super(tensorProto); }
        @Override double get(int i) { return tensorProto.getInt64Val(i); }
        @Override int size() { return tensorProto.getInt64ValCount(); }
    }

    private static class ProtoFloatValues extends ProtoValues {
        ProtoFloatValues(TensorProto tensorProto) { super(tensorProto); }
        @Override double get(int i) { return tensorProto.getFloatVal(i); }
        @Override int size() { return tensorProto.getFloatValCount(); }
    }

    private static class ProtoDoubleValues extends ProtoValues {
        ProtoDoubleValues(TensorProto tensorProto) { super(tensorProto); }
        @Override double get(int i) { return tensorProto.getDoubleVal(i); }
        @Override int size() { return tensorProto.getDoubleValCount(); }
    }

}