summaryrefslogtreecommitdiffstats
path: root/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java
blob: df43225c3332035eff8e5fe091051f069c08efe1 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
package com.yahoo.searchlib.rankingexpression.integration.tensorflow;

import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;

import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;

/**
 * Converts TensorFlow tensors into Vespa tensors.
 *
 * The resulting Vespa tensor has one indexed, bound dimension per TensorFlow dimension,
 * named "d0", "d1", ... in shape order. Only tensors with DOUBLE or FLOAT elements are
 * currently supported.
 *
 * @author bratseth
 */
public class TensorConverter {

    /**
     * Converts the given TensorFlow tensor to a Vespa tensor with the same values.
     *
     * @param tfTensor the TensorFlow tensor to convert
     * @return a Vespa tensor holding the values of the given tensor
     * @throws IllegalArgumentException if the element type of the tensor is not supported,
     *                                  or a dimension is too large to be represented
     */
    public Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor) {
        TensorType type = toVespaTensorType(tfTensor.shape());
        Values values = readValuesOf(tfTensor);
        IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type);
        // Values are copied in the order they are read from TensorFlow, one direct index per element
        for (int i = 0; i < values.size(); i++) {
            builder.cellByDirectIndex(i, values.get(i));
        }
        return builder.build();
    }

    /**
     * Creates a Vespa tensor type with one indexed dimension per entry in the given shape.
     *
     * @param shape the size of each dimension, in order
     * @return a type with dimensions "d0", "d1", ... having the given sizes
     * @throws IllegalArgumentException if a dimension size does not fit in an int
     */
    private TensorType toVespaTensorType(long[] shape) {
        TensorType.Builder b = new TensorType.Builder();
        int dimensionIndex = 0;
        for (long dimensionSize : shape) {
            // A dimension of size 0 is represented as size 1.
            // NOTE(review): the original rationale comment was truncated ("TensorFlow ...") - confirm why
            long size = (dimensionSize == 0) ? 1 : dimensionSize;
            // Fail loudly rather than letting the narrowing cast below wrap around silently
            if (size > Integer.MAX_VALUE)
                throw new IllegalArgumentException("Cannot convert a tensor with a dimension of size " +
                                                   size + " to a Vespa tensor");
            b.indexed("d" + (dimensionIndex++), (int) size);
        }
        return b.build();
    }

    /**
     * Reads the values of the given tensor into a type-independent accessor.
     *
     * @throws IllegalArgumentException if the element type of the tensor is not supported
     */
    private Values readValuesOf(org.tensorflow.Tensor<?> tfTensor) {
        switch (tfTensor.dataType()) {
            case DOUBLE: return new DoubleValues(tfTensor);
            case FLOAT: return new FloatValues(tfTensor);
            // TODO: The rest
            default:
                throw new IllegalArgumentException("Cannot convert a tensor with elements of type " +
                                                   tfTensor.dataType() + " to a Vespa tensor");
        }
    }

    /** Allows reading values from buffers of various numeric types as doubles */
    private static abstract class Values {

        /** The number of values available through {@link #get(int)} */
        private final int size;

        protected Values(int size) {
            this.size = size;
        }

        /** Returns the value at the given index, converted to a double */
        abstract double get(int i);

        /** Returns the number of values */
        int size() { return size; }

    }

    /** Values read from a TensorFlow tensor with elements of type DOUBLE */
    private static final class DoubleValues extends Values {

        private final DoubleBuffer values;

        DoubleValues(org.tensorflow.Tensor<?> tfTensor) {
            super(tfTensor.numElements());
            values = DoubleBuffer.allocate(tfTensor.numElements());
            tfTensor.writeTo(values);
        }

        @Override
        double get(int i) {
            // Absolute get: does not depend on the buffer position
            return values.get(i);
        }

    }

    /** Values read from a TensorFlow tensor with elements of type FLOAT */
    private static final class FloatValues extends Values {

        private final FloatBuffer values;

        FloatValues(org.tensorflow.Tensor<?> tfTensor) {
            super(tfTensor.numElements());
            values = FloatBuffer.allocate(tfTensor.numElements());
            tfTensor.writeTo(values);
        }

        @Override
        double get(int i) {
            // Absolute get: does not depend on the buffer position. The float is widened to double.
            return values.get(i);
        }

    }

}