blob: df43225c3332035eff8e5fe091051f069c08efe1 (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
|
package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
/**
 * Converts TensorFlow tensors into Vespa tensors.
 *
 * @author bratseth
 */
public class TensorConverter {

    /**
     * Converts a TensorFlow tensor to a Vespa tensor.
     *
     * <p>Each TensorFlow dimension becomes an indexed Vespa dimension named {@code d0}, {@code d1}, ...
     * in shape order. Element values are read out of the TensorFlow tensor and written into the Vespa
     * tensor by direct (flat) index, in the order TensorFlow yields them.
     * NOTE(review): this assumes TensorFlow's flat element order matches Vespa's direct-index order
     * for the generated type — confirm.
     *
     * @param tfTensor the TensorFlow tensor to convert; only {@code DOUBLE} and {@code FLOAT} element
     *                 types are currently supported
     * @return a Vespa tensor with the corresponding indexed type and the same element values
     * @throws IllegalArgumentException if the element type of {@code tfTensor} is not supported
     * @throws ArithmeticException if a dimension size does not fit in an {@code int}
     */
    public Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor) {
        TensorType type = toVespaTensorType(tfTensor.shape());
        Values values = readValuesOf(tfTensor);
        // The type built above has only indexed, bound dimensions, for which a BoundBuilder is expected
        IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type);
        for (int i = 0; i < values.size(); i++) {
            builder.cellByDirectIndex(i, values.get(i));
        }
        return builder.build();
    }

    /**
     * Returns the Vespa tensor type corresponding to a TensorFlow shape: one indexed dimension per
     * shape entry, named {@code d<position>}.
     *
     * @param shape the TensorFlow shape, one size per dimension
     * @return the corresponding indexed Vespa tensor type
     * @throws ArithmeticException if a dimension size does not fit in an {@code int}
     */
    private TensorType toVespaTensorType(long[] shape) {
        TensorType.Builder b = new TensorType.Builder();
        int dimensionIndex = 0;
        for (long dimensionSize : shape) {
            // A size of 0 is mapped to 1. NOTE(review): the original rationale comment was truncated
            // ("TensorFlow ..."); presumably this works around how TensorFlow reports some dimensions - confirm.
            if (dimensionSize == 0) dimensionSize = 1;
            // toIntExact fails loudly instead of silently wrapping sizes larger than Integer.MAX_VALUE
            b.indexed("d" + (dimensionIndex++), Math.toIntExact(dimensionSize));
        }
        return b.build();
    }

    /**
     * Copies the element values of a TensorFlow tensor into a {@link Values} view.
     *
     * @throws IllegalArgumentException if the tensor's element type is not DOUBLE or FLOAT
     */
    private Values readValuesOf(org.tensorflow.Tensor<?> tfTensor) {
        switch (tfTensor.dataType()) {
            case DOUBLE: return new DoubleValues(tfTensor);
            case FLOAT: return new FloatValues(tfTensor);
            // TODO: The rest
            default:
                throw new IllegalArgumentException("Cannot convert a tensor with elements of type " +
                                                   tfTensor.dataType() + " to a Vespa tensor");
        }
    }

    /**
     * Read-only access to a tensor's element values, exposed as doubles regardless of the numeric
     * type of the underlying buffer.
     */
    private static abstract class Values {

        private final int size;

        protected Values(int size) {
            this.size = size;
        }

        /** Returns the value at flat index {@code i}, widened to a double. */
        abstract double get(int i);

        /** Returns the number of values. Final so subclass constructors can call it safely. */
        final int size() { return size; }

    }

    /** Values backed by a copy of a DOUBLE tensor's contents. */
    private static class DoubleValues extends Values {

        private final DoubleBuffer values;

        DoubleValues(org.tensorflow.Tensor<?> tfTensor) {
            super(tfTensor.numElements());
            values = DoubleBuffer.allocate(size());
            tfTensor.writeTo(values);
        }

        @Override
        double get(int i) {
            // Absolute get: unaffected by the buffer position advanced by writeTo
            return values.get(i);
        }

    }

    /** Values backed by a copy of a FLOAT tensor's contents; each value is widened to double on read. */
    private static class FloatValues extends Values {

        private final FloatBuffer values;

        FloatValues(org.tensorflow.Tensor<?> tfTensor) {
            super(tfTensor.numElements());
            values = FloatBuffer.allocate(size());
            tfTensor.writeTo(values);
        }

        @Override
        double get(int i) {
            // Absolute get: unaffected by the buffer position advanced by writeTo
            return values.get(i);
        }

    }

}
|